Merge branch 'chieribot' of https://github.com/xiaobaicai2019/MaiMBot into chieribot

pull/541/head
XBC_D2O 2025-03-12 22:29:41 +08:00
commit 2236d61aeb
59 changed files with 5932 additions and 2073 deletions

View File

@ -0,0 +1,11 @@
# 请填写以下内容
(删除掉中括号内的空格,并替换为**小写的x**
1. - [ ] `main` 分支 **禁止修改**,请确认本次提交的分支 **不是 `main` 分支**
2. - [ ] 本次更新 **包含破坏性变更**(如数据库结构变更、配置文件修改等)
3. - [ ] 本次更新是否经过测试
4. 请填写破坏性更新的具体内容(如有):
5. 请简要说明本次更新的内容和目的:
# 其他信息
- **关联 Issue**Close #
- **截图/GIF**
- **附加信息**:

6
.gitignore vendored
View File

@ -1,4 +1,5 @@
data/
data1/
mongodb/
NapCat.Framework.Windows.Once/
log/
@ -193,9 +194,8 @@ cython_debug/
# jieba
jieba.cache
# vscode
/.vscode
# .vscode
!.vscode/settings.json
# direnv
/.direnv

48
CLAUDE.md 100644
View File

@ -0,0 +1,48 @@
# MaiMBot 开发指南
## 🛠️ 常用命令
- **运行机器人**: `python run.py``python bot.py`
- **安装依赖**: `pip install --upgrade -r requirements.txt`
- **Docker 部署**: `docker-compose up`
- **代码检查**: `ruff check .`
- **代码格式化**: `ruff format .`
- **内存可视化**: `run_memory_vis.bat``python -m src.plugins.memory_system.draw_memory`
- **推理过程可视化**: `script/run_thingking.bat`
## 🔧 脚本工具
- **运行MongoDB**: `script/run_db.bat` - 在端口27017启动MongoDB
- **Windows完整启动**: `script/run_windows.bat` - 检查Python版本、设置虚拟环境、安装依赖并运行机器人
- **快速启动**: `script/run_maimai.bat` - 设置UTF-8编码并执行"nb run"命令
## 📝 代码风格
- **Python版本**: 3.9+
- **行长度限制**: 88字符
- **命名规范**:
- `snake_case` 用于函数和变量
- `PascalCase` 用于类
- `_prefix` 用于私有成员
- **导入顺序**: 标准库 → 第三方库 → 本地模块
- **异步编程**: 对I/O操作使用async/await
- **日志记录**: 使用loguru进行一致的日志记录
- **错误处理**: 使用带有具体异常的try/except
- **文档**: 为类和公共函数编写docstrings
## 🧩 系统架构
- **框架**: NoneBot2框架与插件架构
- **数据库**: MongoDB持久化存储
- **设计模式**: 工厂模式和单例管理器
- **配置管理**: 使用环境变量和TOML文件
- **内存系统**: 基于图的记忆结构,支持记忆构建、压缩、检索和遗忘
- **情绪系统**: 情绪模拟与概率权重
- **LLM集成**: 支持多个LLM服务提供商(ChatAnywhere, SiliconFlow, DeepSeek)
## ⚙️ 环境配置
- 使用`template.env`作为环境变量模板
- 使用`template/bot_config_template.toml`作为机器人配置模板
- MongoDB配置: 主机、端口、数据库名
- API密钥配置: 各LLM提供商的API密钥

View File

@ -1,6 +1,5 @@
# 麦麦MaiMBot (编辑中)
<div align="center">
![Python Version](https://img.shields.io/badge/Python-3.9+-blue)
@ -18,7 +17,11 @@
- MongoDB 提供数据持久化支持
- NapCat 作为QQ协议端支持
**最新版本: v0.5.***
**最新版本: v0.5.13**
> [!WARNING]
> 注意3月12日的v0.5.13, 该版本更新较大,建议单独开文件夹部署,然后转移/data文件 和数据库数据库可能需要删除messages下的内容不需要删除记忆
<div align="center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
@ -29,29 +32,35 @@
</a>
</div>
> ⚠️ **注意事项**
> [!WARNING]
> - 项目处于活跃开发阶段,代码可能随时更改
> - 文档未完善,有问题可以提交 Issue 或者 Discussion
> - QQ机器人存在被限制风险请自行了解谨慎使用
> - 由于持续迭代可能存在一些已知或未知的bug
> - 由于开发中可能消耗较多token
**交流群**: 766798517 一群人较多,建议加下面的(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
**交流群**: 571780722 另一个群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
**交流群**: 1035228475 另一个群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
## 💬交流群
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 ,建议加下面的(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722 (开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475开发和建议相关讨论不一定有空回复会优先写文档和代码
**其他平台版本**
**📚 有热心网友创作的wiki:** https://maimbot.pages.dev/
**😊 其他平台版本**
- (由 [CabLate](https://github.com/cablate) 贡献) [Telegram 与其他平台(未来可能会有)的版本](https://github.com/cablate/MaiMBot/tree/telegram) - [集中讨论串](https://github.com/SengokuCola/MaiMBot/discussions/149)
##
<div align="left">
<h2>📚 文档 ⬇️ 快速开始使用麦麦 ⬇️</h2>
</div>
### 部署方式
- 📦 **Windows 一键傻瓜式部署**:请运行项目根目录中的 ```run.bat```,部署完成后请参照后续配置指南进行配置
- 📦 **Windows 一键傻瓜式部署**:请运行项目根目录中的 `run.bat`,部署完成后请参照后续配置指南进行配置
- [📦 Windows 手动部署指南 ](docs/manual_deploy_windows.md)
@ -62,7 +71,9 @@
- [🐳 Docker部署指南](docs/docker_deploy.md)
### 配置说明
- [🎀 新手配置指南](docs/installation_cute.md) - 通俗易懂的配置教程,适合初次使用的猫娘
- [⚙️ 标准配置指南](docs/installation_standard.md) - 简明专业的配置说明,适合有经验的用户
@ -75,6 +86,7 @@
## 🎯 功能介绍
### 💬 聊天功能
- 支持关键词检索主动发言对消息的话题topic进行识别如果检测到麦麦存储过的话题就会主动进行发言
- 支持bot名字呼唤发言检测到"麦麦"会主动发言,可配置
- 支持多模型,多厂商自定义配置
@ -83,31 +95,33 @@
- 错别字和多条回复功能麦麦可以随机生成错别字会多条发送回复以及对消息进行reply
### 😊 表情包功能
- 支持根据发言内容发送对应情绪的表情包
- 会自动偷群友的表情包
### 📅 日程功能
- 麦麦会自动生成一天的日程,实现更拟人的回复
### 🧠 记忆功能
- 对聊天记录进行概括存储,在需要时调用,待完善
### 📚 知识库功能
- 基于embedding模型的知识库手动放入txt会自动识别写完了暂时禁用
### 👥 关系功能
- 针对每个用户创建"关系"可以对不同用户进行个性化回复目前只有极其简单的好感度WIP
- 针对每个群创建"群印象"可以对不同群进行个性化回复WIP
## 开发计划TODOLIST
规划主线
0.6.0:记忆系统更新
0.7.0: 麦麦RunTime
- 人格功能WIP
- 群氛围功能WIP
- 图片发送转发功能WIP
@ -127,7 +141,6 @@
- 采用截断生成加快麦麦的反应速度
- 改进发送消息的触发
## 设计理念
- **千石可乐说:**
@ -137,13 +150,14 @@
- 如果人类真的需要一个AI来陪伴自己并不是所有人都需要一个完美的能解决所有问题的helpful assistant而是一个会犯错的拥有自己感知和想法的"生命形式"。
- 代码会保持开源和开放但个人希望MaiMbot的运行时数据保持封闭尽量避免以显式命令来对其进行控制和调试.我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器.
## 📌 注意事项
SengokuCola纯编程外行面向cursor编程很多代码史一样多多包涵
> ⚠️ **警告**:本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI生成内容不代表本人观点和立场。
SengokuCola~~纯编程外行面向cursor编程很多代码写得不好多多包涵~~已得到大脑升级
> [!WARNING]
> 本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI生成内容不代表本人观点和立场。
## 致谢
[nonebot2](https://github.com/nonebot/nonebot2): 跨平台 Python 异步聊天机器人框架
[NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现
@ -155,6 +169,6 @@ SengokuCola纯编程外行面向cursor编程很多代码史一样多多包
<img src="https://contrib.rocks/image?repo=SengokuCola/MaiMBot" />
</a>
## Stargazers over time
[![Stargazers over time](https://starchart.cc/SengokuCola/MaiMBot.svg?variant=adaptive)](https://starchart.cc/SengokuCola/MaiMBot)

91
bot.py
View File

@ -1,9 +1,12 @@
import asyncio
import os
import shutil
import sys
import nonebot
import time
import uvicorn
from dotenv import load_dotenv
from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter
@ -12,6 +15,8 @@ import platform
# 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ}
uvicorn_server = None
def easter_egg():
# 彩蛋
@ -58,7 +63,7 @@ def init_env():
# 首先加载基础环境变量.env
if os.path.exists(".env"):
load_dotenv(".env")
load_dotenv(".env",override=True)
logger.success("成功加载基础环境变量配置")
@ -94,14 +99,26 @@ def load_env():
def load_logger():
logger.remove() # 移除默认配置
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别默认为INFO
)
if os.getenv("ENVIRONMENT") == "dev":
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别默认为DEBUG
)
else:
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "INFO"), # 根据环境设置日志级别默认为INFO
filter=lambda record: "nonebot" not in record["name"]
)
def scan_provider(env_config: dict):
@ -138,7 +155,39 @@ def scan_provider(env_config: dict):
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
if __name__ == "__main__":
async def graceful_shutdown():
try:
global uvicorn_server
if uvicorn_server:
uvicorn_server.force_exit = True # 强制退出
await uvicorn_server.shutdown()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
logger.error(f"麦麦关闭失败: {e}")
async def uvicorn_main():
global uvicorn_server
config = uvicorn.Config(
app="__main__:app",
host=os.getenv("HOST", "127.0.0.1"),
port=int(os.getenv("PORT", 8080)),
reload=os.getenv("ENVIRONMENT") == "dev",
timeout_graceful_shutdown=5,
log_config=None,
access_log=False
)
server = uvicorn.Server(config)
uvicorn_server = server
await server.serve()
def raw_main():
# 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != 'windows':
@ -165,10 +214,30 @@ if __name__ == "__main__":
nonebot.init(**base_config, **env_config)
# 注册适配器
global driver
driver = nonebot.get_driver()
driver.register_adapter(Adapter)
# 加载插件
nonebot.load_plugins("src/plugins")
nonebot.run()
if __name__ == "__main__":
try:
raw_main()
global app
app = nonebot.get_asgi()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(uvicorn_main())
except KeyboardInterrupt:
logger.warning("麦麦会努力做的更好的!正在停止中......")
except Exception as e:
logger.error(f"主程序异常: {e}")
finally:
loop.run_until_complete(graceful_shutdown())
loop.close()
logger.info("进程终止完毕,麦麦开始休眠......下次再见哦!")

View File

@ -1,6 +1,84 @@
# Changelog
## [0.5.12] - 2025-3-9
### Added
- 新增了 我是测试
## [0.5.13] - 2025-3-12
AI总结
### 🌟 核心功能增强
#### 记忆系统升级
- 新增了记忆系统的时间戳功能,包括创建时间和最后修改时间
- 新增了记忆图节点和边的时间追踪功能
- 新增了自动补充缺失时间字段的功能
- 新增了记忆遗忘机制,基于时间条件自动遗忘旧记忆
- 优化了记忆系统的数据同步机制
- 优化了记忆系统的数据结构,确保所有数据类型的一致性
#### 私聊功能完善
- 新增了完整的私聊功能支持,包括消息处理和回复
- 新增了聊天流管理器,支持群聊和私聊的上下文管理
- 新增了私聊过滤开关功能
- 优化了关系管理系统,支持跨平台用户关系
#### 消息处理升级
- 新增了消息队列管理系统,支持按时间顺序处理消息
- 新增了消息发送控制器,实现人性化的发送速度和间隔
- 新增了JSON格式分享卡片读取支持
- 新增了Base64格式表情包CQ码支持
- 改进了消息处理流程,支持多种消息类型
### 💻 系统架构优化
#### 配置系统改进
- 新增了配置文件自动更新和版本检测功能
- 新增了配置文件热重载API接口
- 新增了配置文件版本兼容性检查
- 新增了根据不同环境(dev/prod)显示不同级别的日志功能
- 优化了配置文件格式和结构
#### 部署支持扩展
- 新增了Linux系统部署指南
- 新增了Docker部署支持的详细文档
- 新增了NixOS环境支持使用venv方式
- 新增了优雅的shutdown机制
- 优化了Docker部署文档
### 🛠️ 开发体验提升
#### 工具链升级
- 新增了ruff代码格式化和检查工具
- 新增了知识库一键启动脚本
- 新增了自动保存脚本,定期保存聊天记录和关系数据
- 新增了表情包自动获取脚本
- 优化了日志记录使用logger.debug替代print
- 精简了日志输出禁用了Uvicorn/NoneBot默认日志
#### 安全性强化
- 新增了API密钥安全管理机制
- 新增了数据库完整性检查功能
- 新增了表情包文件完整性自动检查
- 新增了异常处理和自动恢复机制
- 优化了安全性检查机制
### 🐛 关键问题修复
#### 系统稳定性
- 修复了systemctl强制停止的问题
- 修复了ENVIRONMENT变量在同一终端下不能被覆盖的问题
- 修复了libc++.so依赖问题
- 修复了数据库索引创建失败的问题
- 修复了MongoDB连接配置相关问题
- 修复了消息队列溢出问题
- 修复了配置文件加载时的版本兼容性问题
#### 功能完善性
- 修复了私聊时产生reply消息的bug
- 修复了回复消息无法识别的问题
- 修复了CQ码解析错误
- 修复了情绪管理器导入问题
- 修复了小名无效的问题
- 修复了表情包发送时的参数缺失问题
- 修复了表情包重复注册问题
- 修复了变量拼写错误问题
### 主要改进方向
1. 提升记忆系统的智能性和可靠性
2. 完善私聊功能的完整生态
3. 优化系统架构和部署便利性
4. 提升开发体验和代码质量
5. 加强系统安全性和稳定性

View File

@ -1,6 +1,12 @@
# Changelog
## [0.0.5] - 2025-3-11
### Added
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
## [0.0.4] - 2025-3-9
### Added
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。

View File

@ -0,0 +1,59 @@
import os
import shutil
import tomlkit
from pathlib import Path
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
# 读取旧配置文件
old_config = {}
if old_config_path.exists():
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 删除旧的配置文件
if old_config_path.exists():
os.remove(old_config_path)
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
# 读取新配置文件
with open(new_config_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 递归更新配置
def update_dict(target, source):
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
update_dict(target[key], value)
else:
try:
# 直接使用tomlkit的item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
if __name__ == "__main__":
update_config()

View File

@ -6,8 +6,6 @@ services:
- NAPCAT_UID=${NAPCAT_UID}
- NAPCAT_GID=${NAPCAT_GID} # 让 NapCat 获取当前用户 GID,UID防止权限问题
ports:
- 3000:3000
- 3001:3001
- 6099:6099
restart: unless-stopped
volumes:
@ -19,7 +17,7 @@ services:
mongodb:
container_name: mongodb
environment:
- tz=Asia/Shanghai
- TZ=Asia/Shanghai
# - MONGO_INITDB_ROOT_USERNAME=your_username
# - MONGO_INITDB_ROOT_PASSWORD=your_password
expose:

20
docs/Jonathan R.md 100644
View File

@ -0,0 +1,20 @@
Jonathan R. Wolpaw 在 “Memory in neuroscience: rhetoric versus reality.” 一文中提到,从神经科学的感觉运动假设出发,整个神经系统的功能是将经验与适当的行为联系起来,而不是单纯的信息存储。
Jonathan R,Wolpaw. (2019). Memory in neuroscience: rhetoric versus reality.. Behavioral and cognitive neuroscience reviews(2).
1. **单一过程理论**
- 单一过程理论认为,识别记忆主要是基于熟悉性这一单一因素的影响。熟悉性是指对刺激的一种自动的、无意识的感知,它可以使我们在没有回忆起具体细节的情况下,判断一个刺激是否曾经出现过。
- 例如在一些实验中研究者发现被试可以在没有回忆起具体学习情境的情况下对曾经出现过的刺激做出正确的判断这被认为是熟悉性在起作用1。
2. **双重过程理论**
- 双重过程理论则认为,识别记忆是基于两个过程:回忆和熟悉性。回忆是指对过去经验的有意识的回忆,它可以使我们回忆起具体的细节和情境;熟悉性则是一种自动的、无意识的感知。
- 该理论认为,在识别记忆中,回忆和熟悉性共同作用,使我们能够判断一个刺激是否曾经出现过。例如,在 “记得 / 知道” 范式中被试被要求判断他们对一个刺激的记忆是基于回忆还是熟悉性。研究发现被试可以区分这两种不同的记忆过程这为双重过程理论提供了支持1。
1. **神经元节点与连接**:借鉴神经网络原理,将每个记忆单元视为一个神经元节点。节点之间通过连接相互关联,连接的强度代表记忆之间的关联程度。在形态学联想记忆中,具有相似形态特征的记忆节点连接强度较高。例如,苹果和橘子的记忆节点,由于在形状、都是水果等形态语义特征上相似,它们之间的连接强度大于苹果与汽车记忆节点间的连接强度。
2. **记忆聚类与层次结构**:依据形态特征的相似性对记忆进行聚类,形成不同的记忆簇。每个记忆簇内部的记忆具有较高的相似性,而不同记忆簇之间的记忆相似性较低。同时,构建记忆的层次结构,高层次的记忆节点代表更抽象、概括的概念,低层次的记忆节点对应具体的实例。比如,“水果” 作为高层次记忆节点,连接着 “苹果”“橘子”“香蕉” 等低层次具体水果的记忆节点。
3. **网络的动态更新**:随着新记忆的不断加入,记忆网络动态调整。新记忆节点根据其形态特征与现有网络中的节点建立连接,同时影响相关连接的强度。若新记忆与某个记忆簇的特征高度相似,则被纳入该记忆簇;若具有独特特征,则可能引发新的记忆簇的形成。例如,当系统学习到一种新的水果 “番石榴”,它会根据番石榴的形态、语义等特征,在记忆网络中找到与之最相似的区域(如水果记忆簇),并建立相应连接,同时调整周围节点连接强度以适应这一新记忆。
- **相似性联想**:该理论认为,当两个或多个事物在形态上具有相似性时,它们在记忆中会形成关联。例如,梨和苹果在形状和都是水果这一属性上有相似性,所以当我们看到梨时,很容易通过形态学联想记忆联想到苹果。这种相似性联想有助于我们对新事物进行分类和理解,当遇到一个新的类似水果时,我们可以通过与已有的水果记忆进行相似性匹配,来推测它的一些特征。
- **时空关联性联想**除了相似性联想MAM 还强调时空关联性联想。如果两个事物在时间或空间上经常同时出现,它们也会在记忆中形成关联。比如,每次在公园里看到花的时候,都能听到鸟儿的叫声,那么花和鸟儿叫声的形态特征(花的视觉形态和鸟叫的听觉形态)就会在记忆中形成关联,以后听到鸟叫可能就会联想到公园里的花。

View File

@ -1,6 +1,7 @@
# 📂 文件及功能介绍 (2025年更新)
## 根目录
- **README.md**: 项目的概述和使用说明。
- **requirements.txt**: 项目所需的Python依赖包列表。
- **bot.py**: 主启动文件负责环境配置加载和NoneBot初始化。
@ -10,6 +11,7 @@
- **run_*.bat**: 各种启动脚本包括数据库、maimai和thinking功能。
## `src/` 目录结构
- **`plugins/` 目录**: 存放不同功能模块的插件。
- **chat/**: 处理聊天相关的功能,如消息发送和接收。
- **memory_system/**: 处理机器人的记忆功能。
@ -22,9 +24,10 @@
- **`common/` 目录**: 存放通用的工具和库。
- **database.py**: 处理与数据库的交互,负责数据的存储和检索。
- **__init__.py**: 初始化模块。
- ****init**.py**: 初始化模块。
## `config/` 目录
- **bot_config_template.toml**: 机器人配置模板。
- **auto_format.py**: 自动格式化工具。
@ -110,6 +113,7 @@
## 消息处理流程
### 1. 消息接收与预处理
- 通过 `ChatBot.handle_message()` 接收群消息。
- 进行用户和群组的权限检查。
- 更新用户关系信息。
@ -117,12 +121,14 @@
- 对消息进行过滤和敏感词检测。
### 2. 主题识别与决策
- 使用 `topic_identifier` 识别消息主题。
- 通过记忆系统检查对主题的兴趣度。
- `willing_manager` 动态计算回复概率。
- 根据概率决定是否回复消息。
### 3. 回复生成与发送
- 如需回复,首先创建 `Message_Thinking` 对象表示思考状态。
- 调用 `ResponseGenerator.generate_response()` 生成回复内容和情感状态。
- 删除思考消息,创建 `MessageSet` 准备发送回复。

View File

@ -1,64 +1,90 @@
# 🐳 Docker 部署指南
## 部署步骤(推荐,但不一定是最新)
## 部署步骤 (推荐,但不一定是最新)
**"更新镜像与容器"部分在本文档 [Part 6](#6-更新镜像与容器)**
### 0. 前提说明
**本文假设读者已具备一定的 Docker 基础知识。若您对 Docker 不熟悉,建议先参考相关教程或文档进行学习,或选择使用 [📦Linux手动部署指南](./manual_deploy_linux.md) 或 [📦Windows手动部署指南](./manual_deploy_windows.md) 。**
### 1. 获取Docker配置文件:
### 1. 获取Docker配置文件
- 建议先单独创建好一个文件夹并进入,作为工作目录
```bash
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml -O docker-compose.yml
```
- 若需要启用MongoDB数据库的用户名和密码可进入docker-compose.yml取消MongoDB处的注释并修改变量`=`后方的值为你的用户名和密码\
修改后请注意在之后配置`.env.prod`文件时指定MongoDB数据库的用户名密码
- 若需要启用MongoDB数据库的用户名和密码可进入docker-compose.yml取消MongoDB处的注释并修改变量`=` 后方的值为你的用户名和密码\
修改后请注意在之后配置 `.env.prod` 文件时指定MongoDB数据库的用户名密码
### 2. 启动服务
### 2. 启动服务:
- **!!! 请在第一次启动前确保当前工作目录下`.env.prod`与`bot_config.toml`文件存在 !!!**\
- **!!! 请在第一次启动前确保当前工作目录下 `.env.prod``bot_config.toml` 文件存在 !!!**\
由于Docker文件映射行为的特殊性若宿主机的映射路径不存在可能导致意外的目录创建而不会创建文件由于此处需要文件映射到文件需提前确保文件存在且路径正确可使用如下命令:
```bash
touch .env.prod
touch bot_config.toml
```
- 启动Docker容器:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose up -d
```
- 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
### 3. 修改配置并重启Docker:
### 3. 修改配置并重启Docker
- 请前往 [🎀 新手配置指南](docs/installation_cute.md) 或 [⚙️ 标准配置指南](docs/installation_standard.md) 完成`.env.prod`与`bot_config.toml`配置文件的编写\
**需要注意`.env.prod`中HOST处IP的填写Docker中部署和系统中直接安装的配置会有所不同**
- 重启Docker容器:
```bash
docker restart maimbot # 若修改过容器名称则替换maimbot为你自定的名
docker restart maimbot # 若修改过容器名称则替换maimbot为你自定的名
```
- 下方命令可以但不推荐只是同时重启NapCat、MongoDB、MaiMBot三个服务
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose restart
```
- 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
### 4. 登入NapCat管理页添加反向WebSocket
- 在浏览器地址栏输入`http://<宿主机IP>:6099/`进入NapCat的管理Web页添加一个Websocket客户端
- 在浏览器地址栏输入 `http://<宿主机IP>:6099/` 进入NapCat的管理Web页添加一个Websocket客户端
> 网络配置 -> 新建 -> Websocket客户端
- Websocket客户端的名称自定URL栏填入`ws://maimbot:8080/onebot/v11/ws`,启用并保存即可\
- Websocket客户端的名称自定URL栏填入 `ws://maimbot:8080/onebot/v11/ws`,启用并保存即可\
(若修改过容器名称则替换maimbot为你自定的名称)
### 5. 部署完成,愉快地和麦麦对话吧!
### 5. 愉快地和麦麦对话吧!
### 6. 更新镜像与容器
- 拉取最新镜像
```bash
docker-compose pull
```
- 执行启动容器指令,该指令会自动重建镜像有更新的容器并启动
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose up -d
```
## ⚠️ 注意事项

View File

@ -1,8 +1,9 @@
# 🔧 配置指南 喵~
## 👋 你好呀
## 👋 你好呀
让咱来告诉你我们要做什么喵:
1. 我们要一起设置一个可爱的AI机器人
2. 这个机器人可以在QQ上陪你聊天玩耍哦
3. 需要设置两个文件才能让机器人工作呢
@ -10,16 +11,19 @@
## 📝 需要设置的文件喵
要设置这两个文件才能让机器人跑起来哦:
1. `.env.prod` - 这个文件告诉机器人要用哪些AI服务呢
2. `bot_config.toml` - 这个文件教机器人怎么和你聊天喵
## 🔑 密钥和域名的对应关系
想象一下,你要进入一个游乐园,需要:
1. 知道游乐园的地址(这就是域名 base_url
2. 有入场的门票(这就是密钥 key
`.env.prod` 文件里,我们定义了三个游乐园的地址和门票喵:
```ini
# 硅基流动游乐园
SILICONFLOW_KEY=your_key # 硅基流动的门票
@ -35,6 +39,7 @@ CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere的地
```
然后在 `bot_config.toml` 里,机器人会用这些门票和地址去游乐园玩耍:
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
@ -47,9 +52,10 @@ base_url = "SILICONFLOW_BASE_URL" # 还是去硅基流动游乐园
key = "SILICONFLOW_KEY" # 用同一张门票就可以啦
```
### 🎪 举个例子喵
### 🎪 举个例子喵
如果你想用DeepSeek官方的服务就要这样改
```toml
[model.llm_reasoning]
name = "deepseek-reasoner" # 改成对应的模型名称这里为DeepseekR1
@ -62,7 +68,8 @@ base_url = "DEEP_SEEK_BASE_URL" # 也去DeepSeek游乐园
key = "DEEP_SEEK_KEY" # 用同一张DeepSeek门票
```
### 🎯 简单来说:
### 🎯 简单来说
- `.env.prod` 文件就像是你的票夹,存放着各个游乐园的门票和地址
- `bot_config.toml` 就是告诉机器人:用哪张票去哪个游乐园玩
- 所有模型都可以用同一个游乐园的票,也可以去不同的游乐园玩耍
@ -88,19 +95,25 @@ CHAT_ANY_WHERE_KEY=your_key
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# 如果你不知道这是什么,那么下面这些不用改,保持原样就好啦
HOST=127.0.0.1 # 如果使用Docker部署需要改成0.0.0.0喵,不然听不见群友讲话了喵
# 如果使用Docker部署需要改成0.0.0.0喵,不然听不见群友讲话了喵
HOST=127.0.0.1
PORT=8080
# 这些是数据库设置,一般也不用改呢
MONGODB_HOST=127.0.0.1 # 如果使用Docker部署需要改成数据库容器的名字喵默认是mongodb喵
# 如果使用Docker部署需要把MONGODB_HOST改成数据库容器的名字喵默认是mongodb喵
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 如果数据库需要用户名,就在这里填写喵
MONGODB_PASSWORD = "" # 如果数据库需要密码,就在这里填写呢
MONGODB_AUTH_SOURCE = "" # 数据库认证源,一般不用改哦
# 数据库认证信息,如果需要认证就取消注释并填写下面三行喵
# MONGODB_USERNAME = ""
# MONGODB_PASSWORD = ""
# MONGODB_AUTH_SOURCE = ""
# 插件设置喵
PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢
# 也可以使用URI连接数据库取消注释填写在下面这行喵URI的优先级比上面的高
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# 这里是机器人的插件列表呢
PLUGINS=["src2.plugins.chat"]
```
### 第二个文件:机器人配置 (bot_config.toml)
@ -110,7 +123,8 @@ PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢
```toml
[bot]
qq = "把这里改成你的机器人QQ号喵" # 填写你的机器人QQ号
nickname = "麦麦" # 机器人的名字,你可以改成你喜欢的任何名字哦
nickname = "麦麦" # 机器人的名字你可以改成你喜欢的任何名字哦建议和机器人QQ名称/群昵称一样哦
alias_names = ["小麦", "阿麦"] # 也可以用这个招呼机器人,可以不设置呢
[personality]
# 这里可以设置机器人的性格呢,让它更有趣一些喵
@ -198,10 +212,12 @@ key = "SILICONFLOW_KEY"
- `topic`: 负责理解对话主题的能力呢
## 🌟 小提示
- 如果你刚开始使用,建议保持默认配置呢
- 不同的模型有不同的特长,可以根据需要调整它们的使用比例哦
## 🌟 小贴士喵
- 记得要好好保管密钥key不要告诉别人呢
- 配置文件要小心修改,改错了机器人可能就不能和你玩了喵
- 如果想让机器人更聪明,可以调整 personality 里的设置呢
@ -209,6 +225,7 @@ key = "SILICONFLOW_KEY"
- QQ群号和QQ号都要用数字填写不要加引号哦除了机器人自己的QQ号
## ⚠️ 注意事项
- 这个机器人还在测试中呢,可能会有一些小问题喵
- 如果不知道怎么改某个设置,就保持原样不要动它哦~
- 记得要先有AI服务的密钥不然机器人就不能和你说话了呢

View File

@ -3,14 +3,16 @@
## 简介
本项目需要配置两个主要文件:
1. `.env.prod` - 配置API服务和系统环境
2. `bot_config.toml` - 配置机器人行为和模型
## API配置说明
`.env.prod`和`bot_config.toml`中的API配置关系如下
`.env.prod``bot_config.toml` 中的API配置关系如下
### 在.env.prod中定义API凭证
### 在.env.prod中定义API凭证
```ini
# API凭证配置
SILICONFLOW_KEY=your_key # 硅基流动API密钥
@ -23,7 +25,8 @@ CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere API密钥
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere API地址
```
### 在bot_config.toml中引用API凭证
### 在bot_config.toml中引用API凭证
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
@ -32,6 +35,7 @@ key = "SILICONFLOW_KEY" # 引用.env.prod中定义的密钥
```
如需切换到其他API服务只需修改引用
```toml
[model.llm_reasoning]
name = "deepseek-reasoner" # 改成对应的模型名称这里为DeepseekR1
@ -42,6 +46,7 @@ key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥
## 配置文件详解
### 环境配置文件 (.env.prod)
```ini
# API配置
SILICONFLOW_KEY=your_key
@ -52,26 +57,36 @@ CHAT_ANY_WHERE_KEY=your_key
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# 服务配置
HOST=127.0.0.1 # 如果使用Docker部署需要改成0.0.0.0否则QQ消息无法传入
PORT=8080 # 与反向端口相同
# 数据库配置
MONGODB_HOST=127.0.0.1 # 如果使用Docker部署需要改成数据库容器的名字默认是mongodb
MONGODB_PORT=27017 # MongoDB端口
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 数据库用户名
MONGODB_PASSWORD = "" # 数据库密码
MONGODB_AUTH_SOURCE = "" # 认证数据库
# 数据库认证信息,如果需要认证就取消注释并填写下面三行
# MONGODB_USERNAME = ""
# MONGODB_PASSWORD = ""
# MONGODB_AUTH_SOURCE = ""
# 也可以使用URI连接数据库取消注释填写在下面这行URI的优先级比上面的高
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# 插件配置
PLUGINS=["src2.plugins.chat"]
```
### 机器人配置文件 (bot_config.toml)
```toml
[bot]
qq = "机器人QQ号" # 必填
nickname = "麦麦" # 机器人昵称
# alias_names: 配置机器人可使用的别名。当机器人在群聊或对话中被调用时,别名可以作为直接命令或提及机器人的关键字使用。
# 该配置项为字符串数组。例如: ["小麦", "阿麦"]
alias_names = ["小麦", "阿麦"] # 机器人别名
[personality]
prompt_personality = [

View File

@ -0,0 +1,444 @@
# 面向纯新手的Linux服务器麦麦部署指南
## 你得先有一个服务器
为了能使麦麦在你的电脑关机之后还能运行,你需要一台不间断开机的主机,也就是我们常说的服务器。
华为云、阿里云、腾讯云等等都是在国内可以选择的选择。
你可以去租一台最低配置的就足敷需要了,按月租大概十几块钱就能租到了。
我们假设你已经租好了一台Linux架构的云服务器。我用的是阿里云ubuntu24.04,其他的原理相似。
## 0.我们就从零开始吧
### 网络问题
为访问github相关界面推荐去下一款加速器新手可以试试watttoolkit。
### 安装包下载
#### MongoDB
对于ubuntu24.04 x86来说是这个
https://repo.mongodb.org/apt/ubuntu/dists/noble/mongodb-org/8.0/multiverse/binary-amd64/mongodb-org-server_8.0.5_amd64.deb
如果不是就在这里自行选择对应版本
https://www.mongodb.com/try/download/community-kubernetes-operator
#### Napcat
在这里选择对应版本。
https://github.com/NapNeko/NapCatQQ/releases/tag/v4.6.7
对于ubuntu24.04 x86来说是这个
https://dldir1.qq.com/qqfile/qq/QQNT/ee4bd910/linuxqq_3.2.16-32793_amd64.deb
#### 麦麦
https://github.com/SengokuCola/MaiMBot/archive/refs/tags/0.5.8-alpha.zip
下载这个官方压缩包。
### 路径
我把麦麦相关文件放在了/moi/mai里面你可以凭喜好更改记得适当调整下面涉及到的部分即可。
文件结构:
```
moi
└─ mai
├─ linuxqq_3.2.16-32793_amd64.deb
├─ mongodb-org-server_8.0.5_amd64.deb
└─ bot
└─ MaiMBot-0.5.8-alpha.zip
```
### 网络
你可以在你的服务器控制台网页更改防火墙规则允许6099808027017这几个端口的出入。
## 1.正式开始!
远程连接你的服务器你会看到一个黑框框闪着白方格这就是我们要进行设置的场所——终端了。以下的bash命令都是在这里输入。
## 2. Python的安装
- 导入 Python 的稳定版 PPA
```bash
sudo add-apt-repository ppa:deadsnakes/ppa
```
- 导入 PPA 后,更新 APT 缓存:
```bash
sudo apt update
```
- 在「终端」中执行以下命令来安装 Python 3.12
```bash
sudo apt install python3.12
```
- 验证安装是否成功:
```bash
python3.12 --version
```
- 在「终端」中,执行以下命令安装 pip
```bash
sudo apt install python3-pip
```
- 检查Pip是否安装成功
```bash
pip --version
```
- 安装必要组件
``` bash
sudo apt install python-is-python3
```
## 3.MongoDB的安装
``` bash
cd /moi/mai
```
``` bash
dpkg -i mongodb-org-server_8.0.5_amd64.deb
```
``` bash
mkdir -p /root/data/mongodb/{data,log}
```
## 4.MongoDB的运行
```bash
service mongod start
```
```bash
systemctl status mongod #通过这条指令检查运行状态
```
有需要的话可以把这个服务注册成开机自启
```bash
sudo systemctl enable mongod
```
## 5.napcat的安装
``` bash
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && sudo bash napcat.sh
```
上面的不行试试下面的
``` bash
dpkg -i linuxqq_3.2.16-32793_amd64.deb
apt-get install -f
dpkg -i linuxqq_3.2.16-32793_amd64.deb
```
成功的标志是输入``` napcat ```出来炫酷的彩虹色界面
## 6.napcat的运行
此时你就可以根据提示在```napcat```里面登录你的QQ号了。
```bash
napcat start <你的QQ号>
napcat status #检查运行状态
```
然后你就可以登录napcat的webui进行设置了
```http://<你服务器的公网IP>:6099/webui?token=napcat```
第一次是这个后续改了密码之后token就会对应修改。你也可以使用```napcat log <你的QQ号>```来查看webui地址。把里面的```127.0.0.1```改成<你服务器的公网IP>即可。
登录上之后在网络配置界面添加websocket客户端名称随便输一个url改成`ws://127.0.0.1:8080/onebot/v11/ws`保存之后点启用,就大功告成了。
## 7.麦麦的安装
### step 1 安装解压软件
```
sudo apt-get install unzip
```
### step 2 解压文件
```bash
cd /moi/mai/bot # 注意:要切换到压缩包的目录中去
unzip MaiMBot-0.5.8-alpha.zip
```
### step 3 进入虚拟环境安装库
```bash
cd /moi/mai/bot
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
### step 4 试运行
```bash
cd /moi/mai/bot
python -m venv venv
source venv/bin/activate
python bot.py
```
肯定运行不成功,不过你会发现结束之后多了一些文件
```
bot
├─ .env.prod
└─ config
└─ bot_config.toml
```
你要会vim直接在终端里修改也行不过也可以把它们下到本地改好再传上去
### step 5 文件配置
本项目需要配置两个主要文件:
1. `.env.prod` - 配置API服务和系统环境
2. `bot_config.toml` - 配置机器人行为和模型
#### API
你可以注册一个硅基流动的账号通过邀请码注册有14块钱的免费额度https://cloud.siliconflow.cn/i/7Yld7cfg。
#### 在.env.prod中定义API凭证
```
# API凭证配置
SILICONFLOW_KEY=your_key # 硅基流动API密钥
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ # 硅基流动API地址
DEEP_SEEK_KEY=your_key # DeepSeek API密钥
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 # DeepSeek API地址
CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere API密钥
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere API地址
```
#### 在bot_config.toml中引用API凭证
```
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL" # 引用.env.prod中定义的地址
key = "SILICONFLOW_KEY" # 引用.env.prod中定义的密钥
```
如需切换到其他API服务只需修改引用
```
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "DEEP_SEEK_BASE_URL" # 切换为DeepSeek服务
key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥
```
#### 配置文件详解
##### 环境配置文件 (.env.prod)
```
# API配置
SILICONFLOW_KEY=your_key
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=your_key
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
CHAT_ANY_WHERE_KEY=your_key
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# 服务配置
HOST=127.0.0.1 # 如果使用Docker部署需要改成0.0.0.0否则QQ消息无法传入
PORT=8080
# 数据库配置
MONGODB_HOST=127.0.0.1 # 如果使用Docker部署需要改成数据库容器的名字默认是mongodb
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 数据库用户名
MONGODB_PASSWORD = "" # 数据库密码
MONGODB_AUTH_SOURCE = "" # 认证数据库
# 插件配置
PLUGINS=["src2.plugins.chat"]
```
##### 机器人配置文件 (bot_config.toml)
```
[bot]
qq = "机器人QQ号" # 必填
nickname = "麦麦" # 机器人昵称(你希望机器人怎么称呼它自己)
[personality]
prompt_personality = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书"
]
prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
[message]
min_text_length = 2 # 最小回复长度
max_context_size = 15 # 上下文记忆条数
emoji_chance = 0.2 # 表情使用概率
ban_words = [] # 禁用词列表
[emoji]
auto_save = true # 自动保存表情
enable_check = false # 启用表情审核
check_prompt = "符合公序良俗"
[groups]
talk_allowed = [] # 允许对话的群号
talk_frequency_down = [] # 降低回复频率的群号
ban_user_id = [] # 禁止回复的用户QQ号
[others]
enable_advance_output = true # 启用详细日志
enable_kuuki_read = true # 启用场景理解
# 模型配置
[model.llm_reasoning] # 推理模型
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor] # 轻量推理模型
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal] # 对话模型
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor] # 备用对话模型
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm] # 图像识别模型
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.embedding] # 文本向量模型
name = "BAAI/bge-m3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[topic.llm_topic]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
```
**step # 6** 运行
现在再运行
```bash
cd /moi/mai/bot
python -m venv venv
source venv/bin/activate
python bot.py
```
应该就能运行成功了。
## 8.事后配置
可是现在还有个问题只要你一关闭终端bot.py就会停止运行。那该怎么办呢我们可以把bot.py注册成服务。
重启服务器打开MongoDB和napcat服务。
新建一个文件,名为`bot.service`,内容如下
```
[Unit]
Description=maimai bot
[Service]
WorkingDirectory=/moi/mai/bot
ExecStart=/moi/mai/bot/venv/bin/python /moi/mai/bot/bot.py
Restart=on-failure
User=root
[Install]
WantedBy=multi-user.target
```
里面的路径视自己的情况更改。
把它放到`/etc/systemd/system`里面。
重新加载 `systemd` 配置:
```bash
sudo systemctl daemon-reload
```
启动服务:
```bash
sudo systemctl start bot.service # 启动服务
sudo systemctl restart bot.service # 或者重启服务
```
检查服务状态:
```bash
sudo systemctl status bot.service
```
现在再关闭终端检查麦麦能不能正常回复QQ信息。如果可以的话就大功告成了
## 9.命令速查
```bash
service mongod start # 启动mongod服务
napcat start <你的QQ号> # 登录napcat
cd /moi/mai/bot # 切换路径
python -m venv venv # 创建虚拟环境
source venv/bin/activate # 激活虚拟环境
sudo systemctl daemon-reload # 重新加载systemd配置
sudo systemctl start bot.service # 启动bot服务
sudo systemctl enable bot.service # 启动bot服务
sudo systemctl status bot.service # 检查bot服务状态
```
```
python bot.py
```

View File

@ -1,6 +1,7 @@
# 📦 Linux系统如何手动部署MaiMbot麦麦
## 准备工作
- 一台联网的Linux设备本教程以Ubuntu/Debian系为例
- QQ小号QQ框架的使用可能导致qq被风控严重小概率可能会导致账号封禁强烈不推荐使用大号
- 可用的大模型API
@ -20,6 +21,7 @@
- 数据库是什么如何安装并启动MongoDB
- 如何运行一个QQ机器人以及NapCat框架是什么
---
## 环境配置
@ -33,7 +35,9 @@ python --version
# 或
python3 --version
```
如果版本低于3.9请更新Python版本。
```bash
# Ubuntu/Debian
sudo apt update
@ -45,6 +49,7 @@ sudo update-alternatives --config python3
```
### 2**创建虚拟环境**
```bash
# 方法1使用venv(推荐)
python3 -m venv maimbot
@ -65,32 +70,37 @@ pip install -r requirements.txt
---
## 数据库配置
### 3**安装并启动MongoDB**
- 安装与启动Debian参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-debian/)Ubuntu参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-ubuntu/)
### 3**安装并启动MongoDB**
- 安装与启动Debian参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-debian/)Ubuntu参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-ubuntu/)
- 默认连接本地27017端口
---
## NapCat配置
### 4**安装NapCat框架**
- 参考[NapCat官方文档](https://www.napcat.wiki/guide/boot/Shell#napcat-installer-linux%E4%B8%80%E9%94%AE%E4%BD%BF%E7%94%A8%E8%84%9A%E6%9C%AC-%E6%94%AF%E6%8C%81ubuntu-20-debian-10-centos9)安装
- 使用QQ小号登录添加反向WS地址
`ws://127.0.0.1:8080/onebot/v11/ws`
- 使用QQ小号登录添加反向WS地址: `ws://127.0.0.1:8080/onebot/v11/ws`
---
## 配置文件设置
### 5**配置文件设置让麦麦Bot正常工作**
- 修改环境配置文件:`.env.prod`
- 修改机器人配置文件:`bot_config.toml`
---
## 启动机器人
### 6**启动麦麦机器人**
```bash
# 在项目目录下操作
nb run
@ -100,16 +110,69 @@ python3 bot.py
---
## **其他组件(可选)**
- 直接运行 knowledge.py生成知识库
### 7**使用systemctl管理maimbot**
使用以下命令添加服务文件:
```bash
sudo nano /etc/systemd/system/maimbot.service
```
输入以下内容:
`<maimbot_directory>`你的maimbot目录
`<venv_directory>`你的venv环境就是上文创建环境后执行的代码`source maimbot/bin/activate`中source后面的路径的绝对路径
```ini
[Unit]
Description=MaiMbot 麦麦
After=network.target mongod.service
[Service]
Type=simple
WorkingDirectory=<maimbot_directory>
ExecStart=<venv_directory>/python3 bot.py
ExecStop=/bin/kill -2 $MAINPID
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
```
输入以下命令重新加载systemd
```bash
sudo systemctl daemon-reload
```
启动并设置开机自启:
```bash
sudo systemctl start maimbot
sudo systemctl enable maimbot
```
输入以下命令查看日志:
```bash
sudo journalctl -xeu maimbot
```
---
## **其他组件(可选)**
- 直接运行 knowledge.py生成知识库
---
## 常见问题
🔧 权限问题:在命令前加`sudo`
🔌 端口占用:使用`sudo lsof -i :8080`查看端口占用
🛡️ 防火墙确保8080/27017端口开放
```bash
sudo ufw allow 8080/tcp
sudo ufw allow 27017/tcp

View File

@ -30,12 +30,13 @@
在创建虚拟环境之前请确保你的电脑上安装了Python 3.9及以上版本。如果没有,可以按以下步骤安装:
1. 访问Python官网下载页面https://www.python.org/downloads/release/python-3913/
1. 访问Python官网下载页面<https://www.python.org/downloads/release/python-3913/>
2. 下载Windows安装程序 (64-bit): `python-3.9.13-amd64.exe`
3. 运行安装程序,并确保勾选"Add Python 3.9 to PATH"选项
4. 点击"Install Now"开始安装
或者使用PowerShell自动下载安装需要管理员权限
```powershell
# 下载并安装Python 3.9.13
$pythonUrl = "https://www.python.org/ftp/python/3.9.13/python-3.9.13-amd64.exe"
@ -46,7 +47,7 @@ Start-Process -Wait -FilePath $pythonInstaller -ArgumentList "/quiet", "InstallA
### 2**创建Python虚拟环境来运行程序**
你可以选择使用以下两种方法之一来创建Python环境
> 你可以选择使用以下两种方法之一来创建Python环境
```bash
# ---方法1使用venvPython自带
@ -60,6 +61,7 @@ maimbot\\Scripts\\activate
# 安装依赖
pip install -r requirements.txt
```
```bash
# ---方法2使用conda
# 创建一个新的conda环境环境名为maimbot
@ -74,27 +76,35 @@ pip install -r requirements.txt
```
### 2**然后你需要启动MongoDB数据库来存储信息**
- 安装并启动MongoDB服务
- 默认连接本地27017端口
### 3**配置NapCat让麦麦bot与qq取得联系**
- 安装并登录NapCat用你的qq小号
- 添加反向WS`ws://127.0.0.1:8080/onebot/v11/ws`
- 添加反向WS: `ws://127.0.0.1:8080/onebot/v11/ws`
### 4**配置文件设置让麦麦Bot正常工作**
- 修改环境配置文件:`.env.prod`
- 修改机器人配置文件:`bot_config.toml`
### 5**启动麦麦机器人**
- 打开命令行cd到对应路径
```bash
nb run
```
- 或者cd到对应路径后
```bash
python bot.py
```
### 6**其他组件(可选)**
- `run_thingking.bat`: 启动可视化推理界面(未完善)
- 直接运行 knowledge.py生成知识库

View File

@ -1,43 +1,21 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1741196730,
"narHash": "sha256-0Sj6ZKjCpQMfWnN0NURqRCQn2ob7YtXTAOTwCuz7fkA=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "48913d8f9127ea6530a2a2f1bd4daa1b8685d8a3",
"type": "github"
"lastModified": 0,
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
"type": "path"
},
"original": {
"owner": "NixOS",
"ref": "nixos-24.11",
"repo": "nixpkgs",
"type": "github"
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
"nixpkgs": "nixpkgs",
"utils": "utils"
}
},
"systems": {
@ -54,6 +32,24 @@
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",

View File

@ -1,61 +1,38 @@
{
description = "MaiMBot Nix Dev Env";
# 本配置仅方便用于开发,但是因为 nb-cli 上游打包中并未包含 nonebot2因此目前本配置并不能用于运行和调试
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
flake-utils.url = "github:numtide/flake-utils";
utils.url = "github:numtide/flake-utils";
};
outputs =
{
self,
nixpkgs,
flake-utils,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
};
outputs = {
self,
nixpkgs,
utils,
...
}:
utils.lib.eachDefaultSystem (system: let
pkgs = import nixpkgs {inherit system;};
pythonPackages = pkgs.python3Packages;
in {
devShells.default = pkgs.mkShell {
name = "python-venv";
venvDir = "./.venv";
buildInputs = [
pythonPackages.python
pythonPackages.venvShellHook
pythonPackages.numpy
];
pythonEnv = pkgs.python3.withPackages (
ps: with ps; [
pymongo
python-dotenv
pydantic
jieba
openai
aiohttp
requests
urllib3
numpy
pandas
matplotlib
networkx
python-dateutil
APScheduler
loguru
tomli
customtkinter
colorama
pypinyin
pillow
setuptools
]
);
in
{
devShell = pkgs.mkShell {
buildInputs = [
pythonEnv
pkgs.nb-cli
];
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
pip install -r requirements.txt
'';
shellHook = ''
'';
};
}
);
postShellHook = ''
# allow pip to install wheels
unset SOURCE_DATE_EPOCH
'';
};
});
}

View File

@ -0,0 +1,141 @@
cbb569e - Create 如果你更新了版本,点我.txt
a91ef7b - 自动升级配置文件脚本
ed18f2e - 新增了知识库一键启动漂亮脚本
80ed568 - fix: 删除print调试代码
c681a82 - 修复小名无效问题
e54038f - fix: 从 nixpkgs 增加 numpy 依赖,以避免出现 libc++.so 找不到的问题
26782c9 - fix: 修复 ENVIRONMENT 变量在同一终端下不能被覆盖的问题
8c34637 - 提高健壮性
2688a96 - close SengokuCola/MaiMBot#225 让麦麦可以正确读取分享卡片
cd16e68 - 修复表情包发送时的缺失参数
b362c35 - feat: 更新 flake.nix ,采用 venv 的方式生成环境nixos用户也可以本机运行项目了
3c8c897 - 屏蔽一个臃肿的debug信息
9d0152a - 修复了合并过程中造成的代码重复
956135c - 添加一些注释
a412741 - 将print变为logger.debug
3180426 - 修复了没有改掉的typo字段
aea3bff - 添加私聊过滤开关,更新config,增加约束
cda6281 - chore: update emoji_manager.py
baed856 - 修正了私聊屏蔽词输出
66a0f18 - 修复了私聊时产生reply消息的bug
3bf5cd6 - feat: 新增运行时重载配置文件;新增根据不同环境(dev;prod)显示不同级别的log
33cd83b - 添加私聊功能
aa41f0d - fix: 放反了
ef8691c - fix: 修改message继承逻辑修复回复消息无法识别
7d017be - fix:模型降级
e1019ad - fix: 修复变量拼写错误并优化代码可读性
c24bb70 - fix: 流式输出模式增加结束判断与token用量记录
60a9376 - 添加logger的debug输出开关,默认为不开启
bfa9a3c - fix: 添加群信息获取的错误处理 (#173)
4cc5c8e - 修正.env.prod和.env.dev的生成
dea14c1 - fix: 模型降级目前只对硅基流动的V3和R1生效
b6edbea - fix: 图片保存路径不正确
01a6fa8 - fix: 删除神秘test
20f009d - 修复systemctl强制停止maimbot的问题
af962c2 - 修复了情绪管理器没有正确导入导致发布出消息
0586700 - 按照Sourcery提供的建议修改systemctl管理指南
e48b32a - 在手动部署教程中增加使用systemctl管理
5760412 - fix: 小修
1c9b0cc - fix: 修复部分cq码解析错误merge
b6867b9 - fix: 统一使用os.getenv获取数据库连接信息避免从config对象获取不存在的值时出现KeyError
5e069f7 - 修复记忆保存时无时间信息的bug
73a3e41 - 修复记忆更新bug
52c93ba - refactor: use Base64 for emoji CQ codes
67f6d7c - fix: 保证能运行的小修改
c32c4fb - refactor: 修改配置文件的版本号
a54ca8c - Merge remote-tracking branch 'upstream/debug' into feat_regix
8cbf9bb - feat: 史上最好的消息流重构和图片管理
9e41c4f - feat: 修改 bot_config 0.0.5 版本的变更日志
eede406 - fix: 修复nonebot无法加载项目的问题
00e02ed - fix: 0.0.5 版本的增加分层控制项
0f99d6a - Update docs/docker_deploy.md
c789074 - feat: 增加ruff依赖
ff65ab8 - feat: 修改默认的ruff配置文件同时消除config的所有不符合规范的地方
bf97013 - feat: 精简日志禁用Uvicorn/NoneBot默认日志启动方式改为显示加载uvicorn以便优雅shutdown
d9a2863 - 优化Docker部署文档更新容器部分
efcf00f - Docker部署文档追加更新部分
a63ce96 - fix: 更新情感判断模型配置(使配置文件里的 llm_emotion_judge 生效)
1294c88 - feat: 增加标准化格式化设置
2e8cd47 - fix: 避免可能出现的日程解析错误
043a724 - 修一下文档跳转,小美化(
e4b8865 - 支持别名,可以用不同名称召唤机器人
7b35ddd - ruff 哥又有新点子
7899e67 - feat: 重构完成开始测试debug
354d6d0 - 记忆系统优化
6cef8fd - 修复时区删去napcat用不到的端口
cd96644 - 添加使用说明
84495f8 - fix
204744c - 修改配置名与修改过滤对象为raw_message
a03b490 - Update README.md
2b2b342 - feat: 增加 ruff 依赖
72a6749 - fix: 修复docker部署时区指定问题
ee579bc - Update README.md
1b611ec - resolve SengokuCola/MaiMBot#167 根据正则表达式过滤消息
6e2ea82 - refractor: 几乎写完了,进入测试阶段
2ffdfef - More
e680405 - fix: typo 'discription'
68b3f57 - Minor Doc Update
312f065 - Create linux_deploy_guide_for_beginners.md
ed505a4 - fix: 使用动态路径替换硬编码的项目路径
8ff7bb6 - docs: 更新文档,修正格式并添加必要的换行符
6e36a56 - feat: 增加 MONGODB_URI 的配置项并将所有env文件的注释单独放在一行python的dotenv有时无法正确处理行内注释
4baa6c6 - feat: 实现MongoDB URI方式连接并统一数据库连接代码。
8a32d18 - feat: 优化willing_manager逻辑增加回复保底概率
c9f1244 - docs: 改进README.md文档格式和排版
e1b484a - docs: 添加CLAUDE.md开发指南文件用于Claude Code
a43f949 - fix: remove duplicate message(CR comments)
fddb641 - fix: 修复错误的空值检测逻辑
8b7876c - fix: 修复没有上传tag的问题
6b4130e - feat: 增加stable-dev分支的打包
052e67b - refactor: 日志打印优化(终于改完了,爽了
a7f9d05 - 修复记忆整理传入格式问题
536bb1d - fix: 更新情感判断模型配置
8d99592 - fix: logger初始化顺序
052802c - refactor: logger promotion
8661d94 - doc: README.md - telegram version information
5746afa - refactor: logger in src\plugins\chat\bot.py
288dbb6 - refactor: logger in src\plugins\chat\__init__.py
8428a06 - fix: memory logger optimization (CR comment)
665c459 - 改进了可视化脚本
6c35704 - fix: 调用了错误的函数
3223153 - feat: 一键脚本新增记忆可视化
3149dd3 - fix: mongodb.zip 无法解压 fix:更换执行命令的方法 fix:当 db 不存在时自动创建 feat: 一键安装完成后启动麦麦
089d6a6 - feat: 针对硅基流动的Pro模型添加了自动降级功能
c4b0917 - 一个记忆可视化小脚本
6a71ea4 - 修复了记忆时间bug,config添加了记忆屏蔽关键词
1b5344f - fix: 优化bot初始化的日志&格式
41aa974 - fix: 优化chat/config.py的日志&格式
980cde7 - fix: 优化scheduler_generator日志&格式
31a5514 - fix: 调整全局logger加载顺序
8baef07 - feat: 添加全局logger初始化设置
5566f17 - refractor: 几乎写完了,进入测试阶段
6a66933 - feat: 添加开发环境.env.dev初始化
411ff1a - feat: 安装 MongoDB Compass
0de9eba - feat: 增加实时更新贡献者列表的功能
f327f45 - fix: 优化src/plugins/chat/__init__.py的import
826daa5 - fix: 当虚拟环境存在时跳过创建
f54de42 - fix: time.tzset 仅在类 Unix 系统可用
47c4990 - fix: 修复docker部署场景下时间错误的问题
e23a371 - docs: 添加 compose 注释
1002822 - docs: 标注 Python 最低版本
564350d - feat: 校验 Python 版本
4cc4482 - docs: 添加傻瓜式脚本
757173a - 带麦麦看了心理医生,让她没那么容易陷入负面情绪
39bb99c - 将错别字生成提取到配置,一句一个错别字太烦了!
fe36847 - feat: 超大型重构
e304dd7 - Update README.md
b7cfe6d - feat: 发布第 0.0.2 版本配置模板
ca929d5 - 补充Docker部署文档
1e97120 - 补充Docker部署文档
25f7052 - fix: 修复兼容性选项和目前第一个版本之间的版本间隙 0.0.0 版,并将所有的直接退出修改为抛出异常
c5bdc4f - 防ipv6炸虽然小概率事件
d86610d - fix: 修复不能加载环境变量的问题
2306ebf - feat: 因为判断临界版本范围比较麻烦,增加 notice 字段,删除原本的判断逻辑(存在故障)
dd09576 - fix: 修复 TypeError: BotConfig.convert_to_specifierset() takes 1 positional argument but 2 were given
18f839b - fix: 修复 missing 1 required positional argument: 'INNER_VERSION'
6adb5ed - 调整一些细节docker部署时可选数据库账密
07f48e9 - fix: 利用filter来过滤环境变量避免直接删除key造成的 RuntimeError: dictionary changed size during iteration
5856074 - fix: 修复无法进行基础设置的问题
32aa032 - feat: 发布 0.0.1 版本的配置文件
edc07ac - feat: 重构配置加载器,增加配置文件版本控制和程序兼容能力
0f492ed - fix: 修复 BASE_URL/KEY 组合检查中被 GPG_KEY 干扰的问题

View File

@ -1,23 +1,51 @@
[project]
name = "Megbot"
name = "MaiMaiBot"
version = "0.1.0"
description = "New Bot Project"
description = "MaiMaiBot"
[tool.nonebot]
plugins = ["src.plugins.chat"]
plugin_dirs = ["src/plugins"]
[tool.ruff]
# 设置 Python 版本
target-version = "py39"
include = ["*.py"]
# 行长度设置
line-length = 120
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
# 行长度设置
line-length = 88
ignore = ["E711"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# 使用双引号表示字符串
quote-style = "double"
# 尊重魔法尾随逗号
# 例如:
# items = [
# "apple",
# "banana",
# "cherry",
# ]
skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"

Binary file not shown.

12
run.py
View File

@ -128,13 +128,17 @@ if __name__ == "__main__":
)
os.system("cls")
if choice == "1":
install_napcat()
install_mongodb()
confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n")
if confirm == "1":
install_napcat()
install_mongodb()
else:
print("已取消安装")
elif choice == "2":
run_maimbot()
choice = input("是否启动推理可视化y/N").upper()
choice = input("是否启动推理可视化?(未完善)(y/N").upper()
if choice == "Y":
run_cmd(r"python src\gui\reasoning_gui.py")
choice = input("是否启动记忆可视化?(y/N").upper()
choice = input("是否启动记忆可视化?(未完善)(y/N").upper()
if choice == "Y":
run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")

View File

@ -1,25 +1,47 @@
from typing import Optional
from pymongo import MongoClient
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None):
if username and password:
def __init__(
self,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
):
if uri and uri.startswith("mongodb://"):
# 优先使用URI连接
self.client = MongoClient(uri)
elif username and password:
# 如果有用户名和密码,使用认证连接
# TODO: 复杂情况直接支持URI吧
self.client = MongoClient(host, port, username=username, password=password, authSource=auth_source)
self.client = MongoClient(
host, port, username=username, password=password, authSource=auth_source
)
else:
# 否则使用无认证连接
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None) -> "Database":
def initialize(
cls,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance
@classmethod
@ -27,24 +49,3 @@ class Database:
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
#测试用
def get_random_group_messages(self, group_id: str, limit: int = 5):
# 先随机获取一条消息
random_message = list(self.db.messages.aggregate([
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# 获取该消息之后的消息
subsequent_messages = list(self.db.messages.find({
"group_id": group_id,
"time": {"$gt": random_message["time"]}
}).sort("time", 1).limit(limit))
# 将随机消息和后续消息合并
messages = [random_message] + subsequent_messages
return messages

View File

@ -7,7 +7,7 @@ from datetime import datetime
from typing import Dict, List
from loguru import logger
from typing import Optional
from pymongo import MongoClient
from ..common.database import Database
import customtkinter as ctk
from dotenv import load_dotenv
@ -28,38 +28,6 @@ else:
logger.error("未找到环境配置文件")
sys.exit(1)
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None):
if username and password:
self.client = MongoClient(
host=host,
port=port,
username=username,
password=password,
authSource=auth_source or 'admin'
)
else:
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance
@classmethod
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
class ReasoningGUI:
def __init__(self):
# 记录启动时间戳转换为Unix时间戳
@ -83,11 +51,19 @@ class ReasoningGUI:
except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...")
try:
Database.initialize("127.0.0.1", 27017, "maimai_bot")
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db
logger.success("数据库初始化成功")
except Exception:
logger.exception(f"数据库初始化失败")
logger.exception("数据库初始化失败")
sys.exit(1)
# 存储群组数据
@ -342,7 +318,7 @@ class ReasoningGUI:
'group_id': self.selected_group_id
})
except Exception:
logger.exception(f"自动更新出错")
logger.exception("自动更新出错")
# 每5秒更新一次
time.sleep(5)
@ -359,12 +335,13 @@ class ReasoningGUI:
def main():
"""主函数"""
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
app = ReasoningGUI()

View File

@ -1,9 +1,10 @@
import asyncio
import time
import os
from loguru import logger
from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State
from ...common.database import Database
@ -15,6 +16,7 @@ from .config import global_config
from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager
from .willing_manager import willing_manager
from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
from .message_sender import message_manager, message_sender
@ -31,12 +33,13 @@ driver = get_driver()
config = driver.config
Database.initialize(
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")
@ -47,8 +50,8 @@ emoji_manager.initialize()
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 创建机器人实例
chat_bot = ChatBot()
# 注册消息处理器
group_msg = on_message(priority=5)
# 注册消息处理器
msg_in = on_message(priority=5)
# 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler
@ -96,10 +99,12 @@ async def _(bot: Bot):
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
logger.success("-----------开始偷表情包!-----------")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
@msg_in.handle()
async def _(bot: Bot, event: MessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
@ -121,9 +126,9 @@ async def build_memory_task():
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""每30秒执行一次记忆构建"""
# print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
# await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")

View File

@ -1,27 +1,35 @@
import re
import time
from random import random
import asyncio
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from nonebot.adapters.onebot.v11 import (
Bot,
GroupMessageEvent,
MessageEvent,
PrivateMessageEvent,
)
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
from .cq_code import CQCode # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import (
Message,
Message_Sending,
Message_Thinking, # 导入 Message_Thinking 类
MessageSet,
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import (
MessageRecvCQ,
)
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager
from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot:
@ -41,194 +49,266 @@ class ChatBot:
if not self._started:
self._started = True
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的消息"""
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
"""处理收到的消息"""
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
# 处理私聊消息
if isinstance(event, PrivateMessageEvent):
if not global_config.enable_friend_chat: # 私聊过滤
return
else:
try:
user_info = UserInfo(
user_id=event.user_id,
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
user_cardname=None,
platform="qq",
)
except Exception as e:
logger.error(f"获取陌生人信息失败: {e}")
return
logger.debug(user_info)
await relationship_manager.update_relationship(user_id=event.user_id, data=sender_info)
await relationship_manager.update_relationship_value(user_id=event.user_id, relationship_value=0.5)
# group_info = GroupInfo(group_id=0, group_name="私聊", platform="qq")
group_info = None
message = Message(
group_id=event.group_id,
user_id=event.user_id,
# 处理群聊消息
else:
# 白名单设定由nontbot侧完成
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
user_info = UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
message_cq = MessageRecvCQ(
message_id=event.message_id,
user_cardname=sender_info['card'],
user_info=user_info,
raw_message=str(event.original_message),
plain_text=event.get_plaintext(),
group_info=group_info,
reply_message=event.reply,
platform="qq",
)
await message.initialize()
message_json = message_cq.to_dict()
# 进入maimbot
message = MessageRecv(message_json)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo
)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(
chat_stream=chat,
)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=0.5)
await message.process()
# 过滤词
for word in global_config.ban_words:
if word in message.detailed_plain_text:
if word in message.processed_plain_text:
logger.info(
f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}")
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}"
)
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
# 正则表达式过滤
for pattern in global_config.ban_msgs_regex:
if re.search(pattern, message.raw_message):
logger.info(
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{message.user_nickname}:{message.raw_message}"
)
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = ''
interested_rate = 0
topic = ""
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
logger.debug(f"{message.processed_plain_text}"
f"的激活度:{interested_rate}")
logger.debug(f"{message.processed_plain_text}的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
await self.storage.store_message(message, chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_message(message)
if is_mentioned:
#如果被@等待下文10秒
await asyncio.sleep(10)
logger.info(f"被@,等待下文")
reply_probability = willing_manager.change_reply_willing_received(
event.group_id,
topic[0] if topic else None,
is_mentioned,
global_config,
event.user_id,
message.is_emoji,
interested_rate
reply_probability = await willing_manager.change_reply_willing_received(
chat_stream=chat,
topic=topic[0] if topic else None,
is_mentioned_bot=is_mentioned,
config=global_config,
is_emoji=message.is_emoji,
interested_rate=interested_rate,
)
current_willing = willing_manager.get_willing(event.group_id)
current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info(
f"[{current_time}][{message.group_name}]{message.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]")
f"[{current_time}][{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
)
response = ""
response = None
if random() < reply_probability:
tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message, message_id=think_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform,
)
thinking_time_point = round(time.time(), 2)
think_id = "mt" + str(thinking_time_point)
thinking_message = MessageThinking(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
reply=message,
)
message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id)
willing_manager.change_reply_willing_sent(chat)
response, raw_content = await self.gpt.generate_response(message)
# print(f"response: {response}")
if response:
container = message_manager.get_container(event.group_id)
# print(f"有response: {response}")
container = message_manager.get_container(chat.stream_id)
thinking_message = None
# 找到message,删除
# print(f"开始找思考消息")
for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
# print(f"找到思考消息: {msg}")
thinking_message = msg
container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break
# 如果找不到思考消息,直接返回
if not thinking_message:
logger.warning(f"未找到对应的思考消息,可能已超时被移除")
logger.warning("未找到对应的思考消息,可能已超时被移除")
return
# 记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ,
think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
message_set = MessageSet(chat, think_id)
# 计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False
for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
# 通过时间改变时间戳
typing_time = calculate_typing_time(msg)
logger.debug(f"typing_time: {typing_time}")
accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
timepoint = thinking_time_point + accu_typing_time
message_segment = Seg(type="text", data=msg)
# logger.debug(f"message_segment: {message_segment}")
bot_message = MessageSending(
message_id=think_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=timepoint, # 记录了回复生成的时间
thinking_start_time=thinking_start_time, # 记录了思考开始的时间
reply_message_id=message.message_id
chat_stream=chat,
bot_user_info=bot_user_info,
sender_info=userinfo,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
is_emoji=False,
)
await bot_message.initialize()
print(f"bot_message: {bot_message}")
if not mark_head:
bot_message.is_head = True
mark_head = True
print(f"添加消息到message_set: {bot_message}")
message_set.add_message(bot_message)
# message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
logger.debug("添加message_set到message_manager")
message_manager.add_message(message_set)
bot_response_time = tinking_time_point
bot_response_time = thinking_time_point
if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
# 检查是否 <没有找到> emoji
if emoji_raw != None:
emoji_path, discription = emoji_raw
emoji_path, description = emoji_raw
emoji_cq = CQCode.create_emoji_cq(emoji_path)
emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5:
bot_response_time = tinking_time_point - 1
bot_response_time = thinking_time_point - 1
else:
bot_response_time = bot_response_time + 1
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
processed_plain_text=emoji_cq,
detailed_plain_text=discription,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
message_segment = Seg(type="emoji", data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
sender_info=userinfo,
message_segment=message_segment,
reply=message,
is_head=False,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
)
await bot_message.initialize()
message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content)
logger.debug(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict = {
'happy': 0.5,
'angry': -1,
'sad': -0.5,
'surprised': 0.2,
'disgusted': -1.5,
'fearful': -0.7,
'neutral': 0.1
"happy": 0.5,
"angry": -1,
"sad": -0.5,
"surprised": 0.2,
"disgusted": -1.5,
"fearful": -0.7,
"neutral": 0.1,
}
await relationship_manager.update_relationship_value(message.user_id,
relationship_value=valuedict[emotion[0]])
await relationship_manager.update_relationship_value(
chat_stream=chat, relationship_value=valuedict[emotion[0]]
)
# 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id)
# willing_manager.change_reply_willing_after_sent(
# chat_stream=chat
# )
# 创建全局ChatBot实例

View File

@ -0,0 +1,226 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import GroupInfo, UserInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = (
data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = (
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
group_info=group_info,
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.saved = False
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
# # 启动自动保存任务
# asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
if group_info:
# 组合关键信息
components = [
platform,
str(group_info.group_id)
]
else:
components = [
platform,
str(user_info.user_id),
"private"
]
# 使用MD5生成唯一ID
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息可选
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream=copy.deepcopy(stream)
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
# 创建全局单例
chat_manager = ChatManager()

View File

@ -1,6 +1,7 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Dict, List, Optional
import tomli
from loguru import logger
@ -12,10 +13,12 @@ from packaging.specifiers import SpecifierSet, InvalidSpecifier
@dataclass
class BotConfig:
"""机器人配置类"""
INNER_VERSION: Version = None
BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
# 消息处理相关配置
MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度
@ -34,8 +37,7 @@ class BotConfig:
ban_user_id = set()
build_memory_interval: int = 30 # 记忆构建间隔(秒)
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
EMOJI_SAVE: bool = True # 偷表情包
@ -43,6 +45,7 @@ class BotConfig:
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
ban_words = set()
ban_msgs_regex = set()
max_response_length: int = 1024 # 最大回复长度
@ -64,6 +67,8 @@ class BotConfig:
enable_advance_output: bool = False # 是否启用高级输出
enable_kuuki_read: bool = True # 是否启用读空气功能
enable_debug_output: bool = False # 是否启用调试输出
enable_friend_chat: bool = False # 是否启用好友聊天
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
@ -81,23 +86,31 @@ class BotConfig:
PROMPT_PERSONALITY = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书",
"是一个女大学生你会刷b站对ACG文化感兴趣"
"是一个女大学生你会刷b站对ACG文化感兴趣",
]
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PROMPT_SCHEDULE_GEN = "一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第二种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第二种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
memory_ban_words: list = field(default_factory=lambda: ['表情包', '图片', '回复', '聊天记录']) # 添加新的配置项默认值
build_memory_interval: int = 600 # 记忆构建间隔(秒)
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
memory_compress_rate: float = 0.1 # 记忆压缩率
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@ -113,11 +126,8 @@ class BotConfig:
try:
converted = SpecifierSet(value)
except InvalidSpecifier as e:
logger.error(
f"{value} 分类使用了错误的版本约束表达式\n",
"请阅读 https://semver.org/lang/zh-CN/ 修改代码"
)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@ -131,12 +141,12 @@ class BotConfig:
Version
"""
if 'inner' in toml:
if "inner" in toml:
try:
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
@ -149,7 +159,7 @@ class BotConfig:
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n")
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@ -159,26 +169,26 @@ class BotConfig:
config = cls()
def personality(parent: dict):
personality_config = parent['personality']
personality = personality_config.get('prompt_personality')
personality_config = parent["personality"]
personality = personality_config.get("prompt_personality")
if len(personality) >= 2:
logger.debug(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY)
config.PROMPT_PERSONALITY = personality_config.get("prompt_personality", config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)
config.PROMPT_SCHEDULE_GEN = personality_config.get("prompt_schedule", config.PROMPT_SCHEDULE_GEN)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.PERSONALITY_1 = personality_config.get('personality_1_probability', config.PERSONALITY_1)
config.PERSONALITY_2 = personality_config.get('personality_2_probability', config.PERSONALITY_2)
config.PERSONALITY_3 = personality_config.get('personality_3_probability', config.PERSONALITY_3)
config.PERSONALITY_1 = personality_config.get("personality_1_probability", config.PERSONALITY_1)
config.PERSONALITY_2 = personality_config.get("personality_2_probability", config.PERSONALITY_2)
config.PERSONALITY_3 = personality_config.get("personality_3_probability", config.PERSONALITY_3)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt', config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get('auto_save', config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get('enable_check', config.EMOJI_CHECK)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get("auto_save", config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
def cq_code(parent: dict):
cq_code_config = parent["cq_code"]
@ -191,12 +201,16 @@ class BotConfig:
config.BOT_QQ = int(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
if config.INNER_VERSION in SpecifierSet(">=0.0.5"):
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def response(parent: dict):
response_config = parent["response"]
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability",
config.MODEL_R1_DISTILL_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get(
"model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY
)
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
def model(parent: dict):
@ -213,7 +227,7 @@ class BotConfig:
"llm_emotion_judge",
"vlm",
"embedding",
"moderation"
"moderation",
]
for item in config_list:
@ -222,13 +236,7 @@ class BotConfig:
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name": "",
"base_url": "",
"key": "",
"pri_in": 0,
"pri_out": 0
}
cfg_target = {"name": "", "base_url": "", "key": "", "pri_in": 0, "pri_out": 0}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
@ -247,7 +255,7 @@ class BotConfig:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
provider = cfg_item.get("provider")
if provider is None:
@ -272,12 +280,17 @@ class BotConfig:
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout)
config.response_willing_amplifier = msg_config.get("response_willing_amplifier",
config.response_willing_amplifier)
config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier",
config.response_interested_rate_amplifier)
config.response_willing_amplifier = msg_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = msg_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.6"):
config.ban_msgs_regex = msg_config.get("ban_msgs_regex", config.ban_msgs_regex)
def memory(parent: dict):
memory_config = parent["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
@ -287,6 +300,11 @@ class BotConfig:
if config.INNER_VERSION in SpecifierSet(">=0.0.4"):
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
if config.INNER_VERSION in SpecifierSet(">=0.0.7"):
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
@ -303,10 +321,12 @@ class BotConfig:
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get("tone_error_rate",
config.chinese_typo_tone_error_rate)
config.chinese_typo_word_replace_rate = chinese_typo_config.get("word_replace_rate",
config.chinese_typo_word_replace_rate)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def groups(parent: dict):
groups_config = parent["groups"]
@ -318,6 +338,9 @@ class BotConfig:
others_config = parent["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
if config.INNER_VERSION in SpecifierSet(">=0.0.7"):
config.enable_debug_output = others_config.get("enable_debug_output", config.enable_debug_output)
config.enable_friend_chat = others_config.get("enable_friend_chat", config.enable_friend_chat)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
@ -325,61 +348,19 @@ class BotConfig:
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
include_configs = {
"personality": {
"func": personality,
"support": ">=0.0.0"
},
"emoji": {
"func": emoji,
"support": ">=0.0.0"
},
"cq_code": {
"func": cq_code,
"support": ">=0.0.0"
},
"bot": {
"func": bot,
"support": ">=0.0.0"
},
"response": {
"func": response,
"support": ">=0.0.0"
},
"model": {
"func": model,
"support": ">=0.0.0"
},
"message": {
"func": message,
"support": ">=0.0.0"
},
"memory": {
"func": memory,
"support": ">=0.0.0",
"necessary": False
},
"mood": {
"func": mood,
"support": ">=0.0.0"
},
"keywords_reaction": {
"func": keywords_reaction,
"support": ">=0.0.2",
"necessary": False
},
"chinese_typo": {
"func": chinese_typo,
"support": ">=0.0.3",
"necessary": False
},
"groups": {
"func": groups,
"support": ">=0.0.0"
},
"others": {
"func": others,
"support": ">=0.0.0"
}
"personality": {"func": personality, "support": ">=0.0.0"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"cq_code": {"func": cq_code, "support": ">=0.0.0"},
"bot": {"func": bot, "support": ">=0.0.0"},
"response": {"func": response, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"message": {"func": message, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"groups": {"func": groups, "support": ">=0.0.0"},
"others": {"func": others, "support": ">=0.0.0"},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
@ -391,7 +372,7 @@ class BotConfig:
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except(tomli.TOMLDecodeError) as e:
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
@ -406,7 +387,7 @@ class BotConfig:
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if 'notice' in include_configs[key]:
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
@ -420,7 +401,7 @@ class BotConfig:
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") == False:
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
@ -454,4 +435,8 @@ global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
logger.remove()
pass
# 调试输出功能
if global_config.enable_debug_output:
logger.remove()
logger.add(sys.stdout, level="DEBUG")

View File

@ -1,24 +1,24 @@
import base64
import html
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional
from loguru import logger
from typing import Dict, List, Optional, Union
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .utils_image import storage_emoji, storage_image
from .utils_user import get_user_nickname
from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver()
config = driver.config
@ -36,8 +36,11 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context)
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass
@ -49,52 +52,64 @@ class CQCode:
type: CQ码类型'image', 'at', 'face'
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
# raw_code: str
group_id: int
user_id: int
group_name: str = ""
user_nickname: str = ""
translated_plain_text: Optional[str] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
pass
async def translate(self):
"""根据CQ码类型进行相应的翻译处理"""
if self.type == 'text':
self.translated_plain_text = self.params.get('text', '')
elif self.type == 'image':
if self.params.get('sub_type') == '0':
self.translated_plain_text = await self.translate_image()
def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text":
self.translated_segments = Seg(
type="text", data=self.params.get("text", "")
)
elif self.type == "image":
base64_data = self.translate_image()
if base64_data:
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
else:
self.translated_segments = Seg(type="emoji", data=base64_data)
else:
self.translated_plain_text = await self.translate_emoji()
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname:
self.translated_plain_text = f"[@{user_nickname}]"
self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == "at":
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(
type="text", data=f"[@{user_nickname or '某人'}]"
)
elif self.type == "reply":
reply_segments = self.translate_reply()
if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments)
else:
self.translated_plain_text = "@某人"
elif self.type == 'reply':
self.translated_plain_text = await self.translate_reply()
elif self.type == 'face':
face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]"
elif self.type == 'forward':
self.translated_plain_text = await self.translate_forward()
self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == "forward":
forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
self.translated_segments = Seg(type="text", data="[转发消息]")
else:
self.translated_plain_text = f"[{self.type}]"
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self):
'''
"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
@ -103,18 +118,18 @@ class CQCode:
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
'''
"""
# 腾讯专用请求头配置
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
'Accept': 'text/html, application/xhtml xml, */*',
'Accept-Encoding': 'gbk, GB2312',
'Accept-Language': 'zh-cn',
'Content-Type': 'application/x-www-form-urlencoded',
'Cache-Control': 'no-cache'
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
"Accept-Encoding": "gbk, GB2312",
"Accept-Language": "zh-cn",
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params['url'])
if not url.startswith(('http://', 'https://')):
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
# 创建专用会话
@ -130,223 +145,191 @@ class CQCode:
headers=headers,
timeout=15,
allow_redirects=True,
stream=True # 流式传输避免大内存问题
stream=True, # 流式传输避免大内存问题
)
# 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型
content_type = response.headers.get('Content-Type', '')
if not content_type.startswith('image/'):
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8')
image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1:
logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避
time.sleep(1.5**retry) # 指数退避
except Exception as e:
logger.exception(f"[未知错误]")
except Exception:
logger.exception("[未知错误]")
return None
return None
async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码"""
if 'url' not in self.params:
return '[表情包]'
base64_str = self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
return await self.get_image_description(base64_str)
else:
return '[图片]'
def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params:
return None
return self.get_img()
async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
logger.exception(f"AI接口调用失败: {str(e)}")
return "[表情包]"
if "content" not in self.params:
return None
async def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字可以使用二次元词汇。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
logger.exception(f"AI接口调用失败: {str(e)}")
return "[图片]"
async def translate_forward(self) -> str:
"""处理转发消息"""
try:
if 'content' not in self.params:
return '[转发消息]'
# 解析content内容需要先反转义
content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
# 将字符串形式的列表转换为Python对象
content = self.unescape(self.params["content"])
import ast
try:
messages = ast.literal_eval(content)
except ValueError as e:
logger.error(f"解析转发消息内容失败: {str(e)}")
return '[转发消息]'
return None
# 处理每条消息
formatted_messages = []
formatted_segments = []
for msg in messages:
sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '')
message_array = msg.get('message', [])
sender = msg.get("sender", {})
nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get("raw_message", "")
message_array = msg.get("message", [])
if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息
for message_part in message_array:
if message_part.get('type') == 'forward':
content = '[转发消息]'
if message_part.get("type") == "forward":
content_seg = Seg(type="text", data="[转发消息]")
break
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
if raw_message:
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
else:
content_seg = Seg(type="text", data="[空消息]")
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
content_seg = Seg(type="text", data="[空消息]")
formatted_msg = f"{nickname}: {content}"
formatted_messages.append(formatted_msg)
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_segments.append(content_seg)
formatted_segments.append(Seg(type="text", data="\n"))
# 合并所有消息
combined_messages = '\n'.join(formatted_messages)
logger.debug(f"合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]"
return formatted_segments
except Exception as e:
logger.exception("处理转发消息失败")
return '[转发消息]'
logger.error(f"处理转发消息失败: {str(e)}")
return None
async def translate_reply(self) -> str:
"""处理回复类型的CQ码"""
def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
# 创建Message对象
from .message import Message
if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
if self.reply_message is None:
return None
if self.reply_message.sender.user_id:
message_obj = Message(
user_id=self.reply_message.sender.user_id,
message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname),
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_id=self.group_id
group_info=GroupInfo(group_id=self.reply_message.group_id),
)
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else:
segments.append(
Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type="seglist", data=[message_obj.message_segment]))
segments.append(Seg(type="text", data="]"))
return segments
else:
logger.error("回复消息的sender.user_id为空")
return '[回复某人消息]'
return None
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \
.replace('&#91;', '[') \
.replace('&#93;', ']') \
.replace('&amp;', '&')
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
return (
text.replace("&#44;", ",")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&")
)
class CQCode_tool:
@staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
Args:
cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典可选
Returns:
@ -354,23 +337,23 @@ class CQCode_tool:
"""
# 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get('type', 'text')
cq_type = cq_code.get("type", "text")
params = {}
if cq_type == 'text':
params['text'] = cq_code.get('data', {}).get('text', '')
if cq_type == "text":
params["text"] = cq_code.get("data", {}).get("text", "")
else:
params = cq_code.get('data', {})
params = cq_code.get("data", {})
instance = CQCode(
type=cq_type,
params=params,
group_id=0,
user_id=0,
group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info,
reply_message=reply
)
# 进行翻译处理
await instance.translate()
instance.translate()
return instance
@staticmethod
@ -384,5 +367,64 @@ class CQCode_tool:
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = (
abs_path.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool()

View File

@ -1,9 +1,11 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional
from typing import Optional, Tuple
from loguru import logger
from nonebot import get_driver
@ -11,11 +13,12 @@ from nonebot import get_driver
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
image_manager = ImageManager()
class EmojiManager:
@ -33,7 +36,7 @@ class EmojiManager:
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度
@ -51,8 +54,8 @@ class EmojiManager:
self._initialized = True
# 启动时执行一次完整性检查
self.check_emoji_file_integrity()
except Exception as e:
logger.exception(f"初始化表情管理器失败")
except Exception:
logger.exception("初始化表情管理器失败")
def _ensure_db(self):
"""确保数据库已初始化"""
@ -76,7 +79,6 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
@ -88,9 +90,9 @@ class EmojiManager:
{'$inc': {'usage_count': 1}}
)
except Exception as e:
logger.exception(f"记录表情使用失败")
logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[str]:
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
"""根据文本内容获取相关表情包
Args:
text: 输入文本
@ -117,7 +119,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@ -144,14 +146,14 @@ class EmojiManager:
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis:
if not top_10_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
@ -159,10 +161,11 @@ class EmojiManager:
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success(
f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('description', '无描述')
except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}")
@ -175,13 +178,14 @@ class EmojiManager:
return None
async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签"""
try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容、上面的文字和表情包所表达的情感,可以使用二次元词汇'
"""获取表情包的标签使用image_manager的描述生成功能"""
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}")
return content
try:
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip('[]').replace('表情包:', '')
return description
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
@ -203,7 +207,7 @@ class EmojiManager:
try:
prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
content, _ = await self.llm_emotion_judge.generate_response_async(prompt)
content, _ = await self.llm_emotion_judge.generate_response_async(prompt,temperature=1.5)
logger.info(f"输出描述: {content}")
return content
@ -224,48 +228,108 @@ class EmojiManager:
for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
# 获取图片的base64编码和哈希值
image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# 获取表情包的描述
discription = await self._get_emoji_discription(image_base64)
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述
description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {discription}")
logger.info(f"描述: {description}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}")
continue
logger.info(f"check通过 {check}")
if discription is not None:
embedding = await get_embedding(discription)
if description is not None:
embedding = await get_embedding(description)
if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录
emoji_record = {
'filename': filename,
'path': image_path,
'embedding': embedding,
'discription': discription,
'discription': description,
'hash': image_hash,
'timestamp': int(time.time())
}
# 保存到数据库
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {discription}")
logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else:
logger.warning(f"跳过表情包: {filename}")
except Exception as e:
logger.exception(f"扫描表情包失败")
except Exception:
logger.exception("扫描表情包失败")
async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包"""
@ -332,5 +396,7 @@ class EmojiManager:
# 创建全局单例
emoji_manager = EmojiManager()

View File

@ -8,7 +8,7 @@ from loguru import logger
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import Message
from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@ -19,48 +19,79 @@ config = driver.config
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
self.model_r1 = LLM_request(
model=global_config.llm_reasoning,
temperature=0.7,
max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1'
self.current_model_type = "r1"
current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = 'v3'
elif (
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3
else:
self.current_model_type = 'r1_distill'
self.current_model_type = "r1_distill"
current_model = self.model_r1_distill
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response
model_response = await self._generate_response_with_model(
message, current_model
)
raw_content = model_response
# print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}")
if model_response:
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response)
if model_response:
return model_response, raw_content
return None, raw_content
return model_response ,raw_content
return None,raw_content
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
async def _generate_response_with_model(
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}"
if message.user_cardname:
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
sender_name = (
message.chat_stream.user_info.user_nickname
or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass
@ -70,7 +101,7 @@ class ResponseGenerator:
message_txt=message.processed_plain_text,
sender_name=sender_name,
relationship_value=relationship_value,
group_id=message.group_id
stream_id=message.chat_stream.stream_id,
)
# 读空气模块 简化逻辑,先停用
@ -94,7 +125,7 @@ class ResponseGenerator:
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception:
logger.exception(f"生成回复时出错")
logger.exception("生成回复时出错")
return None
# 保存到数据库
@ -113,40 +144,57 @@ class ResponseGenerator:
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, reasoning_content: str,):
def _save_to_db(
self,
message: MessageRecv,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
'prompt_check': prompt_check
})
self.db.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
"user": sender_name,
"message": message.processed_plain_text,
"model": self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
"reasoning": reasoning_content,
"response": content,
"prompt": prompt,
"prompt_check": prompt_check,
}
)
async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签"""
try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好不要输出其他内容:
内容{content}
输出
'''
"""
content, _ = await self.model_v25.generate_response(prompt)
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
content = content.strip()
if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content]
else:
return ["neutral"]
except Exception:
logger.exception(f"获取情感标签时出错")
except Exception as e:
print(f"获取情感标签时出错: {e}")
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
@ -156,6 +204,8 @@ class ResponseGenerator:
processed_response = process_llm_response(content)
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response

View File

@ -1,231 +1,431 @@
import time
import html
import re
import json
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional
from typing import Dict, List, Optional
import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream, chat_manager
Message = ForwardRef('Message') # 添加这行
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
class Message(MessageBase):
chat_stream: ChatStream=None
reply: Optional['Message'] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
)
# 构建详细文本
if self.time is None:
self.time = int(time.time())
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
def __init__(
self,
message_id: str,
time: int,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=time,
group_info=chat_stream.group_info,
user_info=user_info
)
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
# 调用父类初始化
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=None
)
async def parse_message_segments(self, message: str) -> List[CQCode]:
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
self.detailed_plain_text = detailed_plain_text
# 回复消息
self.reply = reply
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典包含
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
message_segment = message_dict.get('message_segment', {})
if message_segment.get('data','') == '[json]':
# 提取json消息中的展示信息
pattern = r'\[CQ:json,data=(?P<json_data>.+?)\]'
match = re.search(pattern, message_dict.get('raw_message',''))
raw_json = html.unescape(match.group('json_data'))
try:
json_message = json.loads(raw_json)
except json.JSONDecodeError:
json_message = {}
message_segment['data'] = json_message.get('prompt','')
self.message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
self.raw_message = message_dict.get('raw_message')
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用因为它包含异步操作
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
self.processed_plain_text = await self._process_message_segments(
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text()
start = 0
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
Args:
segment: 要处理的消息段
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == "text":
return seg.data
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return "[图片]"
elif seg.type == "emoji":
self.is_emoji = True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return "[表情]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime(
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ""
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
)
self.message_id = message_id
# 思考状态相关属性
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.interupt=False
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == "text":
return seg.data
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return "[图片]"
elif seg.type == "emoji":
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return "[表情]"
elif seg.type == "at":
return f"[@{seg.data}]"
elif seg.type == "reply":
if self.reply and hasattr(self.reply, "processed_plain_text"):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime(
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ""
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class Message_Sending(Message):
"""发送中的消息类"""
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
reply_message_id: int = None # 存储 回复的 源消息ID
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional["MessageRecv"] = None,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段
reply=reply,
)
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
return self.thinking_time
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
sender_info: UserInfo, # 用来记录发送者信息,用于私聊回复
message_segment: Seg,
reply: Optional["MessageRecv"] = None,
is_head: bool = False,
is_emoji: bool = False,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
)
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
def set_reply(self, reply: Optional["MessageRecv"]) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=reply.message_info.message_id),
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False,
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji,
)
def to_dict(self):
ret = super().to_dict()
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
return ret
def is_private_message(self) -> bool:
"""判断是否为私聊消息"""
return (
self.message_info.group_info is None
or self.message_info.group_info.group_id is None
)
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id
self.user_id = user_id
def __init__(self, chat_stream: ChatStream, message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注
self.messages: List[MessageSending] = []
self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None:
"""添加消息到集合只接受Message_Sending类型"""
if not isinstance(message, Message_Sending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
# 按时间排序
self.messages.sort(key=lambda x: x.time)
self.messages.sort(key=lambda x: x.message_info.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].time < target_time:
if self.messages[mid].message_info.time < target_time:
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool:
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)
@ -237,6 +437,3 @@ class MessageSet:
def __len__(self) -> int:
return len(self.messages)

View File

@ -0,0 +1,186 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data base64 字符串
- 对于 seglist 类型data Seg 列表
translated_data: 经过翻译处理的数据可选
"""
type: str
data: Union[str, List['Seg']]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> 'Seg':
"""从字典创建Seg实例"""
type=data.get('type')
data=data.get('data')
if type == 'seglist':
data = [Seg.from_dict(seg) for seg in data]
return cls(
type=type,
data=data
)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {'type': self.type}
if self.type == 'seglist':
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
group_id=data.get('group_id'),
group_name=data.get('group_name',None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
user_id=data.get('user_id'),
user_nickname=data.get('user_nickname',None),
user_cardname=data.get('user_cardname',None)
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str,int,None] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo(**data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {}))
return cls(
platform=data.get('platform'),
message_id=data.get('message_id'),
time=data.get('time'),
group_info=group_info,
user_info=user_info
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {
'message_info': self.message_info.to_dict(),
'message_segment': self.message_segment.to_dict()
}
if self.raw_message is not None:
result['raw_message'] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> 'MessageBase':
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message',None)
return cls(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)

View File

@ -0,0 +1,173 @@
import time
from dataclasses import dataclass
from typing import Dict, Optional
import urllib3
from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识默认为"qq"
"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
raw_message: str,
group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None,
):
# 调用父类初始化
super().__init__(message_id, user_info, group_info, platform)
# 私聊消息不携带group_info
if group_info is None:
pass
elif group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
data: Dict
):
# 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg.from_dict(data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, ) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 处理消息段
if self.message_segment.type == 'seglist':
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji':
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@ -5,10 +5,10 @@ from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet
from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageRecv, MessageSet
from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config
@ -24,64 +24,61 @@ class Message_Sender:
"""设置当前bot实例"""
self._current_bot = bot
async def send_group_message(
self,
group_id: int,
send_text: str,
auto_escape: bool = False,
reply_message_id: int = None,
at_user_id: int = None
async def send_message(
self,
message: MessageSending,
) -> None:
"""发送消息"""
if not self._current_bot:
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例")
message = send_text
# 如果需要回复
if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message
# 如果需要at
# if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message
typing_time = calculate_typing_time(message)
if typing_time > 10:
typing_time = 10
await asyncio.sleep(typing_time)
# 发送消息
try:
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
)
logger.debug(f"发送消息{message}成功")
except Exception as e:
logger.exception(f"发送消息{message}失败")
if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send = MessageSendCQ(data=message_json)
# logger.debug(message_send.message_info,message_send.raw_message)
if (
message_send.message_info.group_info
and message_send.message_info.group_info.group_id
):
try:
await self._current_bot.send_group_msg(
group_id=message.message_info.group_info.group_id,
message=message_send.raw_message,
auto_escape=False,
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception as e:
logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
else:
try:
logger.debug(message.message_info.user_info)
await self._current_bot.send_private_msg(
user_id=message.sender_info.user_id,
message=message_send.raw_message,
auto_escape=False,
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception as e:
logger.error(f"发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
class MessageContainer:
"""单个的发送/思考消息容器"""
"""单个聊天流的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id
def __init__(self, chat_id: str, max_size: int = 100):
self.chat_id = chat_id
self.max_size = max_size
self.messages = []
self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]:
def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time()
timeout_messages = []
for msg in self.messages:
if isinstance(msg, Message_Sending):
if isinstance(msg, MessageSending):
if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg)
@ -90,11 +87,11 @@ class MessageContainer:
return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
"""获取thinking_start_time最早的消息对象"""
if not self.messages:
return None
earliest_time = float('inf')
earliest_time = float("inf")
earliest_message = None
for msg in self.messages:
msg_time = msg.thinking_start_time
@ -103,16 +100,15 @@ class MessageContainer:
earliest_message = msg
return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
"""添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
if isinstance(message, MessageSet):
for single_message in message.messages:
self.messages.append(single_message)
else:
self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False"""
try:
if message in self.messages:
@ -120,101 +116,109 @@ class MessageContainer:
return True
return False
except Exception:
logger.exception(f"移除消息时发生错误")
logger.exception("移除消息时发生错误")
return False
def has_messages(self) -> bool:
"""检查是否有待发送的消息"""
return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
"""获取所有消息"""
return list(self.messages)
class MessageManager:
"""管理所有的消息容器"""
"""管理所有聊天流的消息容器"""
def __init__(self):
self.containers: Dict[int, MessageContainer] = {}
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
self.storage = MessageStorage()
self._running = True
def get_container(self, group_id: int) -> MessageContainer:
"""获取或创建的消息容器"""
if group_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id)
return self.containers[group_id]
def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建聊天流的消息容器"""
if chat_id not in self.containers:
self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[chat_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
container = self.get_container(message.group_id)
def add_message(
self, message: Union[MessageThinking, MessageSending, MessageSet]
) -> None:
chat_stream = message.chat_stream
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
container.add_message(message)
async def process_group_messages(self, group_id: int):
"""处理群消息"""
# if int(time.time() / 3) == time.time() / 3:
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id)
async def process_chat_messages(self, chat_id: str):
"""处理聊天流消息"""
container = self.get_container(chat_id)
if container.has_messages():
# 最早的对象,可能是思考消息,也可能是发送消息
message_earliest = container.get_earliest_message() # 一个message_thinking or message_sending
# print(f"处理有message的容器chat_id: {chat_id}")
message_earliest = container.get_earliest_message()
# 如果是思考消息
if isinstance(message_earliest, Message_Thinking):
# 优先等待这条消息
if isinstance(message_earliest, MessageThinking):
message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time
print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
print(
f"消息正在思考中,已思考{int(thinking_time)}\r",
end="",
flush=True,
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest)
else: # 如果不是message_thinking就只能是message_sending
logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
# 直接发,等什么呢
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False,
reply_message_id=message_earliest.reply_message_id)
else:
if (
message_earliest.is_head
and message_earliest.update_thinking_time() > 30
and not message_earliest.is_private_message() # 避免在私聊时插入reply
):
await message_sender.send_message(message_earliest.set_reply())
else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False)
# 移除消息
if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None)
await message_sender.send_message(message_earliest)
await message_earliest.process()
print(
f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中"
)
await self.storage.store_message(
message_earliest, message_earliest.chat_stream, None
)
container.remove_message(message_earliest)
# 获取并处理超时消息
message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
message_timeout = container.get_timeout_messages()
if message_timeout:
logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout:
if msg == message_earliest:
continue # 跳过已经处理过的消息
continue
try:
# 发送
if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False,
reply_message_id=msg.reply_message_id)
if (
msg.is_head
and msg.update_thinking_time() > 30
and not message_earliest.is_private_message() # 避免在私聊时插入reply
):
await message_sender.send_message(msg.set_reply())
else:
await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False)
await message_sender.send_message(msg)
# 如果是表情包,则替换为"[表情包]"
if msg.is_emoji:
msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None)
# if msg.is_emoji:
# msg.processed_plain_text = "[表情包]"
await msg.process()
await self.storage.store_message(msg, msg.chat_stream, None)
# 安全地移除消息
if not container.remove_message(msg):
logger.warning("尝试删除不存在的消息")
except Exception:
logger.exception(f"处理超时消息时发生错误")
logger.exception("处理超时消息时发生错误")
continue
async def start_processor(self):
@ -222,8 +226,8 @@ class MessageManager:
while self._running:
await asyncio.sleep(1)
tasks = []
for group_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id))
for chat_id in self.containers.keys():
tasks.append(self.process_chat_messages(chat_id))
await asyncio.gather(*tasks)

View File

@ -9,6 +9,7 @@ from ..moods.moods import MoodManager
#from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import chat_manager
class PromptBuilder:
@ -17,11 +18,13 @@ class PromptBuilder:
self.activate_messages = ''
self.db = Database.get_instance()
async def _build_prompt(self,
message_txt: str,
sender_name: str = "某人",
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
message_txt: str,
sender_name: str = "某人",
relationship_value: float = 0.0,
stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt
Args:
@ -70,13 +73,19 @@ class PromptBuilder:
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文
chat_in_group=True
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# 使用新的记忆获取方法
memory_prompt = ''
@ -108,15 +117,10 @@ class PromptBuilder:
# 激活prompt构建
activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot:
# is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认'
# else:
# is_bot_prompt = ''
if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 关键词检测与反应
keywords_reaction_prompt = ''
@ -126,26 +130,33 @@ class PromptBuilder:
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
keywords_reaction_prompt += rule.get("reaction", "") + ''
# 人格选择
personality = global_config.PROMPT_PERSONALITY
#人格选择
personality=global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3
prompt_personality = ''
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}'
personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复{keywords_reaction_prompt}
请注意把握群里的聊天内容回复可以有个性'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[1]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复{keywords_reaction_prompt}
请注意把握群里的聊天内容回复可以有个性'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复{keywords_reaction_prompt}
请注意把握群里的聊天内容回复可以有个性'''
# 中文高手(新加的好玩功能)
prompt_ger = ''
if random.random() < 0.04:

View File

@ -1,9 +1,10 @@
import asyncio
from loguru import logger
from typing import Optional
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo
from .chat_stream import ChatStream
class Impression:
traits: str = None
@ -15,59 +16,66 @@ class Impression:
class Relationship:
user_id: int = None
# impression: Impression = None
# group_id: int = None
# group_name: str = None
platform: str = None
gender: str = None
age: int = None
nickname: str = None
relationship_value: float = None
saved = False
def __init__(self, user_id: int, data=None, **kwargs):
if isinstance(data, dict):
# 如果输入是字典,使用字典解析
self.user_id = data.get('user_id')
self.gender = data.get('gender')
self.age = data.get('age')
self.nickname = data.get('nickname')
self.relationship_value = data.get('relationship_value', 0.0)
self.saved = data.get('saved', False)
else:
# 如果是直接传入属性值
self.user_id = kwargs.get('user_id')
self.gender = kwargs.get('gender')
self.age = kwargs.get('age')
self.nickname = kwargs.get('nickname')
self.relationship_value = kwargs.get('relationship_value', 0.0)
self.saved = kwargs.get('saved', False)
def __init__(self, chat:ChatStream=None,data:dict=None):
self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
self.platform=chat.platform if chat else data.get('platform','')
self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0) if data else 0
self.age=data.get('age',0) if data else 0
self.gender=data.get('gender','') if data else ''
class RelationshipManager:
def __init__(self):
self.relationships: dict[int, Relationship] = {}
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self,
chat_stream:ChatStream,
data: dict = None,
**kwargs) -> Optional[Relationship]:
"""更新或创建关系
Args:
chat_stream: 聊天流对象
data: 字典格式的数据可选
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if chat_stream.user_info is not None:
user_id = chat_stream.user_info.user_id
platform = chat_stream.user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship(self, user_id: int, data=None, **kwargs):
# 检查是否在内存中已存在
relationship = self.relationships.get(user_id)
relationship = self.relationships.get(key)
if relationship:
# 如果存在,更新现有对象
if isinstance(data, dict):
for key, value in data.items():
if hasattr(relationship, key) and value is not None:
setattr(relationship, key, value)
else:
for key, value in kwargs.items():
if hasattr(relationship, key) and value is not None:
setattr(relationship, key, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
# 如果不存在,创建新对象
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id,
**kwargs)
self.relationships[user_id] = relationship
# 更新 id_name_nickname_table
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
if chat_stream.user_info is not None:
relationship = Relationship(chat=chat_stream, **kwargs)
else:
raise ValueError("必须提供user_id或user_info")
self.relationships[key] = relationship
# 保存到数据库
await self.storage_relationship(relationship)
@ -75,32 +83,86 @@ class RelationshipManager:
return relationship
async def update_relationship_value(self, user_id: int, **kwargs):
async def update_relationship_value(self,
chat_stream:ChatStream,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象可选
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
# 检查是否在内存中已存在
relationship = self.relationships.get(user_id)
relationship = self.relationships.get(key)
if relationship:
for key, value in kwargs.items():
if key == 'relationship_value':
for k, value in kwargs.items():
if k == 'relationship_value':
relationship.relationship_value += value
await self.storage_relationship(relationship)
relationship.saved = True
return relationship
else:
logger.warning(f"用户 {user_id} 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
return None
def get_relationship(self, user_id: int) -> Optional[Relationship]:
"""获取用户关系对象"""
if user_id in self.relationships:
return self.relationships[user_id]
def get_relationship(self,
chat_stream:ChatStream) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象可选
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
platform = chat_stream.user_info.platform or 'qq'
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
else:
return 0
async def load_relationship(self, data: dict) -> Relationship:
"""从数据库加载或创建新的关系对象"""
rela = Relationship(user_id=data['user_id'], data=data)
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True
self.relationships[rela.user_id] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela
async def load_all_relationships(self):
@ -117,10 +179,8 @@ class RelationshipManager:
all_relationships = db.db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
user_id = data['user_id']
relationship = await self.load_relationship(data)
self.relationships[user_id] = relationship
logger.debug(f"已加载 {len(self.relationships)} 条关系记录")
await self.load_relationship(data)
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
while True:
logger.debug("正在自动保存关系")
@ -130,16 +190,15 @@ class RelationshipManager:
async def _save_all_relationships(self):
"""将所有关系数据保存到数据库"""
# 保存所有关系数据
for userid, relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items():
if not relationship.saved:
relationship.saved = True
await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship):
"""
将关系记录存储到数据库中
"""
"""将关系记录存储到数据库中"""
user_id = relationship.user_id
platform = relationship.platform
nickname = relationship.nickname
relationship_value = relationship.relationship_value
gender = relationship.gender
@ -148,8 +207,9 @@ class RelationshipManager:
db = Database.get_instance()
db.db.relationships.update_one(
{'user_id': user_id},
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,
'nickname': nickname,
'relationship_value': relationship_value,
'gender': gender,
@ -159,12 +219,36 @@ class RelationshipManager:
upsert=True
)
def get_name(self, user_id: int) -> str:
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象可选
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 确保user_id是整数类型
user_id = int(user_id)
if user_id in self.relationships:
return self.relationships[user_id].nickname
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
else:
return "某人"

View File

@ -1,7 +1,8 @@
from typing import Optional
from typing import Optional, Union
from ...common.database import Database
from .message import Message
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
@ -9,42 +10,21 @@ class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
if not message.is_emoji:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
message_data = {
"message_id": message.message_info.message_id,
"time": message.message_info.time,
"chat_id":chat_stream.stream_id,
"chat_info": chat_stream.to_dict(),
"user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text,
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
else:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
"processed_plain_text": '[表情包]',
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
self.db.db.messages.insert_one(message_data)
except Exception:
logger.exception(f"存储消息失败")
logger.exception("存储消息失败")
# 如果需要其他存储相关的函数,可以在这里添加

View File

@ -12,32 +12,15 @@ from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message import Message
from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
"""将消息列表组合成格式化的字符串
Args:
messages: Message对象列表
Returns:
str: 格式化后的消息字符串
"""
result = ""
for message in messages:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n"
return result
def db_message_to_str(message_dict: Dict) -> str:
logger.debug(f"message_dict: {message_dict}")
@ -53,20 +36,15 @@ def db_message_to_str(message_dict: Dict) -> str:
return result
def is_mentioned_bot_in_message(message: Message) -> bool:
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
for keyword in keywords:
if keyword in message.processed_plain_text:
return True
return False
def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
if keyword in message:
for nickname in nicknames:
if nickname in message.processed_plain_text:
return True
return False
@ -99,46 +77,45 @@ def calculate_information_content(text):
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
"""从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
Returns:
list: 消息记录字典列表每个字典包含消息内容和时间信息
list: 消息记录列表每个记录包含时间和文本信息
"""
chat_records = []
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id']
# 获取该时间戳之后的length条消息且groupid相同
records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
chat_id = closest_record['chat_id'] # 获取chat_id
# 获取该时间戳之后的length条消息保持相同的chat_id
chat_records = list(db.db.messages.find(
{
"time": {"$gt": closest_time},
"chat_id": chat_id # 添加chat_id过滤
}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in records:
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
print("消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
# 添加到记录列表中
chat_records.append({
'text': record["detailed_plain_text"],
# 转换记录格式
formatted_records = []
for record in chat_records:
formatted_records.append({
'time': record["time"],
'group_id': record["group_id"]
'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
})
return chat_records
return formatted_records
return []
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@ -152,35 +129,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
{"chat_id": chat_id},
).sort("time", -1).limit(limit))
if not recent_messages:
return []
# 转换为 Message对象列表
from .message import Message
message_objects = []
for msg_data in recent_messages:
try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
detailed_plain_text=msg_data.get("detailed_plain_text", "")
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
logger.warning("数据库中存在无效的消息")
@ -191,13 +161,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"chat_id":1,
"chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}

View File

@ -1,296 +1,353 @@
import base64
import io
import os
import time
import zlib # 用于 CRC32
import aiohttp
import hashlib
from typing import Optional, Union
from loguru import logger
from nonebot import get_driver
from PIL import Image
from ...common.database import Database
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args:
base64_data: base64编码的图片数据
max_size: 最大文件大小KB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
# 确保图片目录存在
images_dir = "data/images"
os.makedirs(images_dir, exist_ok=True)
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
# 连接数据库
db = Database(
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
image_hash: 图片哈希值
description_type: 描述类型 ('emoji' 'image')
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji' 'image')
"""
self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
)
# 检查是否已存在相同哈希值的图片
collection = db.db['images']
existing_image = collection.find_one({'hash': hash_value})
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
# 如果已经小于目标大小,直接使用原图
if current_size <= max_size:
compressed_data = image_data
else:
# 压缩逻辑
# 先缩放到50%
new_width = int(img.width * 0.5)
new_height = int(img.height * 0.5)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 如果缩放后的最大边长仍然大于400继续缩放
max_dimension = 400
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# 使用固定质量参数压缩
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(images_dir, filename)
# 保存文件
with open(image_path, "wb") as f:
f.write(compressed_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
async def save_image(self,
image_data: Union[str, bytes],
url: str = None,
description: str = None,
is_base64: bool = False) -> Optional[str]:
"""保存图像
Args:
image_data: 图像数据(base64字符串或字节)
url: 图像URL
description: 图像描述
is_base64: image_data是否为base64格式
Returns:
str: 保存后的文件路径,失败返回None
"""
try:
# 准备数据库记录
image_record = {
'filename': filename,
'path': image_path,
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
}
# 保存记录
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes:
"""
存储表情包到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
if not global_config.EMOJI_SAVE:
return image_data
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(emoji_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
emoji_path = os.path.join(emoji_dir, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
image_dir = "data/image"
os.makedirs(image_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(image_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
# 转换为字节格式
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return None
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return None
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
return base64.b64encode(compressed_data).decode('utf-8')
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'url': url,
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
return file_path
except Exception as e:
logger.error(f"保存图像失败: {str(e)}")
return None
async def get_image_by_url(self, url: str) -> Optional[str]:
"""根据URL获取图像路径(带查重)
Args:
url: 图像URL
Returns:
str: 本地文件路径,不存在返回None
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
if existing:
return existing['path']
# 下载图像
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
image_bytes = await resp.read()
return await self.save_image(image_bytes, url=url)
return None
except Exception as e:
logger.error(f"获取图像失败: {str(e)}")
return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description:
logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'emoji',
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库
self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'image',
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库
self._save_description_to_db(image_hash, description, 'image')
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
# 创建全局单例
image_manager = ImageManager()
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码

View File

@ -1,110 +1,104 @@
import asyncio
from loguru import logger
from typing import Dict
from .config import global_config
from .chat_stream import ChatStream
class WillingManager:
def __init__(self):
self.group_reply_willing = {} # 存储每个群的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None
self._started = False
self.min_reply_willing = 0.15
self.attenuation_coefficient = 0.75
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(5)
for group_id in self.group_reply_willing:
self.group_reply_willing[group_id] = max(
self.min_reply_willing,
self.group_reply_willing[group_id] * self.attenuation_coefficient
)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self, group_id: int) -> float:
"""获取指定群组的回复意愿"""
return self.group_reply_willing.get(group_id, 0)
def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定聊天流的回复意愿"""
stream = chat_stream
if stream:
return self.chat_reply_willing.get(stream.stream_id, 0)
return 0
def set_willing(self, group_id: int, willing: float):
"""设置指定群组的回复意愿"""
self.group_reply_willing[group_id] = willing
def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing
def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config,
user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
async def change_reply_willing_received(self,
chat_stream:ChatStream,
topic: str = None,
is_mentioned_bot: bool = False,
config = None,
is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流
stream = chat_stream
chat_id = stream.stream_id
# 若非目标回复群组则直接return
if group_id not in config.talk_allowed_groups:
reply_probability = 0
return reply_probability
current_willing = self.chat_reply_willing.get(chat_id, 0)
current_willing = self.group_reply_willing.get(group_id, 0)
logger.debug(f"[{group_id}]的初始回复意愿: {current_willing}")
if current_willing<0.25:
current_willing+=0.03 #多说点话
# 根据消息类型被cue/表情包)调控
if is_mentioned_bot:
current_willing = min(
3.0,
current_willing + 0.9
)
logger.debug(f"被提及, 当前意愿: {current_willing}")
# print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot:
current_willing += 0.05
print(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji:
current_willing *= 0.1
logger.debug(f"表情包, 当前意愿: {current_willing}")
print(f"表情包, 当前意愿: {current_willing}")
# 兴趣放大系数,若兴趣 > 0.4则增加回复概率
interested_rate_amplifier = global_config.response_interested_rate_amplifier
logger.debug(f"放大系数_interested_rate: {interested_rate_amplifier}")
interested_rate *= interested_rate_amplifier
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
if interested_rate > 0.4:
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.4
current_willing += max(
0.0,
interested_rate - 0.4
)
current_willing *= global_config.response_willing_amplifier #放大回复意愿
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
# 回复意愿系数调控,独立乘区
willing_amplifier = max(
global_config.response_willing_amplifier,
self.min_reply_willing
)
current_willing *= willing_amplifier
logger.debug(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
reply_probability = max((current_willing - 0.45) * 2, 0)
# 回复概率迭代保底0.01回复概率
reply_probability = max(
#(current_willing - 0.45) * 2,
current_willing,
self.min_reply_willing
)
# 降低目标低频群组回复概率
down_frequency_rate = max(
1.0,
global_config.down_frequency_rate
)
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / down_frequency_rate
# 检查群组权限(如果是群聊)
if chat_stream.group_info:
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
reply_probability = min(reply_probability, 1)
if reply_probability < 0:
reply_probability = 0
self.group_reply_willing[group_id] = min(current_willing, 3.0)
logger.debug(f"当前群组{group_id}回复概率:{reply_probability}")
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
return reply_probability
def change_reply_willing_sent(self, group_id: int):
"""开始思考后降低群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
self.group_reply_willing[group_id] = max(0, current_willing - 2)
def change_reply_willing_sent(self, chat_stream:ChatStream):
"""开始思考后降低聊天流的回复意愿"""
stream = chat_stream
if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int):
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
def change_reply_willing_after_sent(self,chat_stream:ChatStream):
"""发送消息后提高聊天流的回复意愿"""
stream = chat_stream
if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
if current_willing < 1:
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""
@ -113,6 +107,5 @@ class WillingManager:
self._decay_task = asyncio.create_task(self._decay_reply_willing())
self._started = True
# 创建全局实例
willing_manager = WillingManager()

View File

@ -0,0 +1,10 @@
from nonebot import get_app
from .api import router
from loguru import logger
# 获取主应用实例并挂载路由
app = get_app()
app.include_router(router, prefix="/api")
# 打印日志方便确认API已注册
logger.success("配置重载API已注册可通过 /api/reload-config 访问")

View File

@ -0,0 +1,17 @@
from fastapi import APIRouter, HTTPException
from src.plugins.chat.config import BotConfig
import os
# 创建APIRouter而不是FastAPI实例
router = APIRouter()
@router.post("/reload-config")
async def reload_config():
try:
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
global_config = BotConfig.load_config(config_path=bot_config_path)
return {"message": "配置重载成功", "status": "success"}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}")

View File

@ -0,0 +1,3 @@
import requests
response = requests.post("http://localhost:8080/api/reload-config")
print(response.json())

View File

@ -1,198 +0,0 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env.dev")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
)
class KnowledgeLibrary:
def __init__(self):
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self):
"""处理raw_info目录下的所有txt文件"""
for filename in os.listdir(self.raw_info_dir):
if filename.endswith('.txt'):
file_path = os.path.join(self.raw_info_dir, filename)
self.process_single_file(file_path)
def process_single_file(self, file_path: str):
"""处理单个文件"""
try:
# 检查文件是否已处理
if self.db.db.processed_files.find_one({"file_path": file_path}):
print(f"文件已处理过,跳过: {file_path}")
return
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 按1024字符分段
segments = [content[i:i+600] for i in range(0, len(content), 600)]
# 处理每个分段
for segment in segments:
if not segment.strip(): # 跳过空段
continue
# 获取embedding
embedding = self.get_embedding(segment)
if not embedding:
continue
# 存储到数据库
doc = {
"content": segment,
"embedding": embedding,
"file_path": file_path,
"segment_length": len(segment)
}
# 使用文本内容的哈希值作为唯一标识
content_hash = hash(segment)
# 更新或插入文档
self.db.db.knowledges.update_one(
{"content_hash": content_hash},
{"$set": doc},
upsert=True
)
# 记录文件已处理
self.db.db.processed_files.insert_one({
"file_path": file_path,
"processed_time": time.time()
})
print(f"成功处理文件: {file_path}")
except Exception as e:
print(f"处理文件 {file_path} 时出错: {str(e)}")
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
# 测试知识库功能
print("开始处理知识库文件...")
knowledge_library.process_files()
# 测试搜索功能
test_query = "麦麦评价一下僕と花"
print(f"\n搜索与'{test_query}'相似的内容:")
results = knowledge_library.search_similar_segments(test_query)
for result in results:
print(f"相似度: {result['similarity']:.4f}")
print(f"内容: {result['content'][:100]}...")
print("-" * 50)

View File

@ -9,7 +9,10 @@ import networkx as nx
from dotenv import load_dotenv
from loguru import logger
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
@ -162,12 +165,13 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph()

View File

@ -3,11 +3,13 @@ import datetime
import math
import random
import time
import os
import jieba
import networkx as nx
from loguru import logger
from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
from ..chat.utils import (
@ -18,33 +20,52 @@ from ..chat.utils import (
)
from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength
# 避免自连接
if concept1 == concept2:
return
current_time = datetime.datetime.now().timestamp()
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
# 更新最后修改时间
self.G[concept1][concept2]['last_modified'] = current_time
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2,
strength=1,
created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间
def add_dot(self, concept, memory):
current_time = datetime.datetime.now().timestamp()
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
# 更新最后修改时间
self.G.nodes[concept]['last_modified'] = current_time
else:
self.G.nodes[concept]['memory_items'] = [memory]
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
if 'created_time' not in self.G.nodes[concept]:
self.G.nodes[concept]['created_time'] = current_time
self.G.nodes[concept]['last_modified'] = current_time
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept,
memory_items=[memory],
created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间
def get_dot(self, concept):
# 检查节点是否存在于图中
@ -191,15 +212,11 @@ class Hippocampus:
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩消息记录为记忆
Args:
messages: 消息记录字典列表每个字典包含text和time字段
compress_rate: 压缩率
Returns:
set: (话题, 记忆) 元组集合
tuple: (压缩记忆集合, 相似主题字典)
"""
if not messages:
return set()
return set(), {}
# 合并消息文本,同时保留时间信息
input_text = ""
@ -222,7 +239,7 @@ class Hippocampus:
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages:
input_text += f"{msg['text']}\n"
input_text += f"{msg['detailed_plain_text']}\n"
logger.debug(input_text)
@ -246,12 +263,33 @@ class Hippocampus:
# 等待所有任务完成
compressed_memory = set()
similar_topics_dict = {} # 存储每个话题的相似主题列表
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
# 为每个话题查找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes())
similar_topics = []
return compressed_memory
for existing_topic in existing_topics:
topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic))
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= 0.6:
similar_topics.append((existing_topic, similarity))
similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:5]
similar_topics_dict[topic] = similar_topics
return compressed_memory, similar_topics_dict
def calculate_topic_num(self, text, compress_rate):
"""计算文本的话题数量"""
@ -265,33 +303,45 @@ class Hippocampus:
return topic_num
async def operation_build_memory(self, chat_size=20):
# 最近消息获取频率
time_frequency = {'near': 2, 'mid': 4, 'far': 2}
memory_sample = self.get_memory_sample(chat_size, time_frequency)
time_frequency = {'near': 1, 'mid': 4, 'far': 4}
memory_samples = self.get_memory_sample(chat_size, time_frequency)
for i, input_text in enumerate(memory_sample, 1):
# 加载进度可视化
for i, messages in enumerate(memory_samples, 1):
all_topics = []
progress = (i / len(memory_sample)) * 100
# 加载进度可视化
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
filled_length = int(bar_length * i // len(memory_samples))
bar = '' * filled_length + '-' * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set()
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)}")
compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
current_time = datetime.datetime.now().timestamp()
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题
all_topics.append(topic)
# 连接相似的已存在主题
if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic]
for similar_topic, similarity in similar_topics:
if topic != similar_topic:
strength = int(similarity * 10)
logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
self.memory_graph.G.add_edge(topic, similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time)
# 连接同批次的相关话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
logger.info(f"连接节点: {all_topics[i]}{all_topics[j]}")
logger.info(f"连接同批次节点: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
@ -302,7 +352,7 @@ class Hippocampus:
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式方便查找
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node['concept']: node for node in db_nodes}
# 检查并更新节点
@ -314,12 +364,18 @@ class Hippocampus:
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
# 获取时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
# 数据库中缺少的节点,添加
node_data = {
'concept': concept,
'memory_items': memory_items,
'hash': memory_hash
'hash': memory_hash,
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
@ -327,25 +383,21 @@ class Hippocampus:
db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None)
# 如果特征值不同则更新节点
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
'hash': memory_hash
'hash': memory_hash,
'created_time': created_time,
'last_modified': last_modified
}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
db_edge_dict = {}
@ -357,10 +409,14 @@ class Hippocampus:
}
# 检查并更新边
for source, target in memory_edges:
for source, target, data in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1)
strength = data.get('strength', 1)
# 获取边的时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
if edge_key not in db_edge_dict:
# 添加新边
@ -368,7 +424,9 @@ class Hippocampus:
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash
'hash': edge_hash,
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
else:
@ -378,88 +436,168 @@ class Hippocampus:
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
'strength': strength
'strength': strength,
'created_time': created_time,
'last_modified': last_modified
}}
)
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
self.memory_graph.db.db.graph_data.edges.delete_one({
'source': source,
'target': target
})
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp()
need_update = False
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find()
nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在
if 'created_time' not in node or 'last_modified' not in node:
need_update = True
# 更新数据库中的节点
update_data = {}
if 'created_time' not in node:
update_data['created_time'] = current_time
if 'last_modified' not in node:
update_data['last_modified'] = current_time
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': update_data}
)
logger.info(f"为节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.get('created_time', current_time)
last_modified = node.get('last_modified', current_time)
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
self.memory_graph.G.add_node(concept,
memory_items=memory_items,
created_time=created_time,
last_modified=last_modified)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
edges = list(self.memory_graph.db.db.graph_data.edges.find())
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength默认为 1
strength = edge.get('strength', 1)
# 检查时间字段是否存在
if 'created_time' not in edge or 'last_modified' not in edge:
need_update = True
# 更新数据库中的边
update_data = {}
if 'created_time' not in edge:
update_data['created_time'] = current_time
if 'last_modified' not in edge:
update_data['last_modified'] = current_time
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': update_data}
)
logger.info(f"为边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get('created_time', current_time)
last_modified = edge.get('last_modified', current_time)
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
self.memory_graph.G.add_edge(source, target,
strength=strength,
created_time=created_time,
last_modified=last_modified)
if need_update:
logger.success("已为缺失的时间字段进行补充")
async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
# 获取所有节点
"""随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
# 检查数据库是否为空
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
all_edges = list(self.memory_graph.G.edges())
forgotten_nodes = []
if not all_nodes and not all_edges:
logger.info("记忆图为空,无需进行遗忘操作")
return
check_nodes_count = max(1, int(len(all_nodes) * percentage))
check_edges_count = max(1, int(len(all_edges) * percentage))
nodes_to_check = random.sample(all_nodes, check_nodes_count)
edges_to_check = random.sample(all_edges, check_edges_count)
edge_changes = {'weakened': 0, 'removed': 0}
node_changes = {'reduced': 0, 'removed': 0}
current_time = datetime.datetime.now().timestamp()
# 检查并遗忘连接
logger.info("开始检查连接...")
for source, target in edges_to_check:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get('last_modified')
# print(source,target)
# print(f"float(last_modified):{float(last_modified)}" )
# print(f"current_time:{current_time}")
# print(f"current_time - last_modified:{current_time - last_modified}")
if current_time - last_modified > 3600*global_config.memory_forget_time: # test
current_strength = edge_data.get('strength', 1)
new_strength = current_strength - 1
if new_strength <= 0:
self.memory_graph.G.remove_edge(source, target)
edge_changes['removed'] += 1
logger.info(f"\033[1;31m[连接移除]\033[0m {source} - {target}")
else:
edge_data['strength'] = new_strength
edge_data['last_modified'] = current_time
edge_changes['weakened'] += 1
logger.info(f"\033[1;34m[连接减弱]\033[0m {source} - {target} (强度: {current_strength} -> {new_strength})")
# 检查并遗忘话题
logger.info("开始检查节点...")
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
node_data = self.memory_graph.G.nodes[node]
last_modified = node_data.get('last_modified', current_time)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
if current_time - last_modified > 3600*24: # test
memory_items = node_data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
if strength > 2:
weak_connections = False
break
if memory_items:
current_count = len(memory_items)
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
logger.debug(f"遗忘节点 {node} 的记忆: {removed_item}")
if memory_items:
self.memory_graph.G.nodes[node]['memory_items'] = memory_items
self.memory_graph.G.nodes[node]['last_modified'] = current_time
node_changes['reduced'] += 1
logger.info(f"\033[1;33m[记忆减少]\033[0m {node} (记忆数量: {current_count} -> {len(memory_items)})")
else:
self.memory_graph.G.remove_node(node)
node_changes['removed'] += 1
logger.info(f"\033[1;31m[节点移除]\033[0m {node}")
# 同步到数据库
if forgotten_nodes:
if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
self.sync_memory_to_db()
logger.debug(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
logger.info("\n遗忘操作统计:")
logger.info(f"连接变化: {edge_changes['weakened']} 个减弱, {edge_changes['removed']} 个移除")
logger.info(f"节点变化: {node_changes['reduced']} 个减少记忆, {node_changes['removed']} 个移除")
else:
logger.debug("本次检查没有节点满足遗忘条件")
logger.info("\n本次检查没有节点或连接满足遗忘条件")
async def merge_memory(self, topic):
"""
@ -486,7 +624,7 @@ class Hippocampus:
logger.debug(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(selected_memories, 0.1)
compressed_memories, _ = await self.memory_compress(selected_memories, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
@ -749,21 +887,21 @@ def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
from nonebot import get_driver
driver = get_driver()
config = driver.config
start_time = time.time()
Database.initialize(
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图
memory_graph = Memory_graph()

View File

@ -10,13 +10,15 @@ from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
import pymongo
from dotenv import load_dotenv
from loguru import logger
import jieba
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
@ -35,45 +37,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@ -941,7 +904,15 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
@ -1008,9 +979,6 @@ async def main():
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@ -7,10 +7,11 @@ from typing import Tuple, Union
import aiohttp
from loguru import logger
from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver()
config = driver.config
@ -44,8 +45,8 @@ class LLM_request:
self.db.db.llm_usage.create_index([("model_name", 1)])
self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败")
except Exception:
logger.error("创建数据库索引失败")
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat",
@ -80,7 +81,7 @@ class LLM_request:
f"总计: {total_tokens}"
)
except Exception:
logger.error(f"记录token使用情况失败")
logger.error("记录token使用情况失败")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
@ -183,18 +184,21 @@ class LLM_request:
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
if response.status == 403:
# 尝试降级Pro模型
#只针对硅基流动的V3和R1进行降级处理
if self.model_name.startswith(
"Pro/") and self.base_url == "https://api.siliconflow.cn/v1/":
"Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get(
'name') == old_model_name:
if global_config.llm_normal.get('name') == old_model_name:
global_config.llm_normal['name'] = self.model_name
logger.warning(f"已将全局配置中的 llm_normal 模型降级")
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get('name') == old_model_name:
global_config.llm_reasoning['name'] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
# 更新payload中的模型名
if payload and 'model' in payload:
@ -210,6 +214,7 @@ class LLM_request:
# 将流式输出转化为非流式输出
if stream_mode:
flag_delta_content_finished = False
accumulated_content = ""
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
@ -221,13 +226,25 @@ class LLM_request:
break
try:
chunk = json.loads(data_str)
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
if flag_delta_content_finished:
usage = chunk.get("usage", None) # 获取tokn用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0]["finish_reason"]
if finish_reason == "stop":
usage = chunk.get("usage", None)
if usage:
break
# 部分平台在文本输出结束前不会返回token用量此时需要再获取一次chunk
flag_delta_content_finished = True
except Exception:
logger.exception(f"解析流式输出错")
logger.exception("解析流式输出错")
content = accumulated_content
reasoning_content = ""
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
@ -236,7 +253,7 @@ class LLM_request:
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}], "usage": usage}
return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
else:
@ -355,7 +372,7 @@ class LLM_request:
"""构建请求头"""
if no_key:
return {
"Authorization": f"Bearer **********",
"Authorization": "Bearer **********",
"Content-Type": "application/json"
}
else:
@ -432,3 +449,78 @@ class LLM_request:
response_handler=embedding_handler
)
return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data

View File

@ -1,3 +1,4 @@
import os
import datetime
import json
from typing import Dict, Union
@ -14,15 +15,15 @@ driver = get_driver()
config = driver.config
Database.initialize(
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator:
def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型
@ -68,7 +69,7 @@ class ScheduleGenerator:
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表仅返回内容不要返回注释时间采用24小时制格式为{"时间": "活动","时间": "活动",...}"""
请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表仅返回内容不要返回注释不要添加任何markdown或代码块样式时间采用24小时制格式为{"时间": "活动","时间": "活动",...}"""
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
@ -91,7 +92,7 @@ class ScheduleGenerator:
try:
schedule_dict = json.loads(schedule_text)
return schedule_dict
except json.JSONDecodeError as e:
except json.JSONDecodeError:
logger.exception("解析日程失败: {}".format(schedule_text))
return False

View File

@ -155,7 +155,7 @@ class LLMStatistics:
all_stats = self._collect_all_statistics()
self._save_statistics(all_stats)
except Exception:
logger.exception(f"统计数据处理失败")
logger.exception("统计数据处理失败")
# 等待1分钟
for _ in range(60):

View File

@ -0,0 +1,383 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import Database
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env.prod")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
# 初始化数据库连接
if Database._instance is None:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def read_file(self, file_path: str) -> str:
"""读取文件内容"""
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def split_content(self, content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,保持段落完整性
Args:
content: 要分割的文本内容
max_length: 每个块的最大长度
Returns:
list: 分割后的文本块列表
"""
# 首先按段落分割
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
chunks = []
current_chunk = []
current_length = 0
for para in paragraphs:
para_length = len(para)
# 如果单个段落就超过最大长度
if para_length > max_length:
# 如果当前chunk不为空先保存
if current_chunk:
chunks.append('\n'.join(current_chunk))
current_chunk = []
current_length = 0
# 将长段落按句子分割
sentences = [s.strip() for s in para.replace('', '\n').replace('', '\n').replace('', '\n').split('\n') if s.strip()]
temp_chunk = []
temp_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if sentence_length > max_length:
# 如果单个句子超长,强制按长度分割
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
temp_chunk = []
temp_length = 0
for i in range(0, len(sentence), max_length):
chunks.append(sentence[i:i + max_length])
elif temp_length + sentence_length + 1 <= max_length:
temp_chunk.append(sentence)
temp_length += sentence_length + 1
else:
chunks.append('\n'.join(temp_chunk))
temp_chunk = [sentence]
temp_length = sentence_length
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
# 如果当前段落加上现有chunk不超过最大长度
elif current_length + para_length + 1 <= max_length:
current_chunk.append(para)
current_length += para_length + 1
else:
# 保存当前chunk并开始新的chunk
chunks.append('\n'.join(current_chunk))
current_chunk = [para]
current_length = para_length
# 添加最后一个chunk
if current_chunk:
chunks.append('\n'.join(current_chunk))
return chunks
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self, knowledge_length:int=512):
"""处理raw_info目录下的所有txt文件"""
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
if not txt_files:
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
return
total_stats = {
"processed_files": 0,
"total_chunks": 0,
"failed_files": [],
"skipped_files": []
}
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
for filename in tqdm(txt_files, desc="处理文件进度"):
file_path = os.path.join(self.raw_info_dir, filename)
result = self.process_single_file(file_path, knowledge_length)
self._update_stats(total_stats, result, filename)
self._display_processing_results(total_stats)
def process_single_file(self, file_path: str, knowledge_length: int = 512):
"""处理单个文件"""
result = {
"status": "success",
"chunks_processed": 0,
"error": None
}
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
if knowledge_length in processed_record.get("split_by", []):
result["status"] = "skipped"
return result
content = self.read_file(file_path)
chunks = self.split_content(content, knowledge_length)
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
embedding = self.get_embedding(chunk)
if embedding:
knowledge = {
"content": chunk,
"embedding": embedding,
"source_file": file_path,
"split_length": knowledge_length,
"created_at": datetime.now()
}
self.db.db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
self.db.db.processed_files.update_one(
{"file_path": file_path},
{
"$set": {
"hash": current_hash,
"last_processed": datetime.now(),
"split_by": split_by
}
},
upsert=True
)
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
return result
def _update_stats(self, total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
total_stats["total_chunks"] += result["chunks_processed"]
elif result["status"] == "failed":
total_stats["failed_files"].append((filename, result["error"]))
elif result["status"] == "skipped":
total_stats["skipped_files"].append(filename)
def _display_processing_results(self, stats):
"""显示处理结果统计"""
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
table = Table(show_header=True, header_style="bold magenta")
table.add_column("统计项", style="dim")
table.add_column("数值")
table.add_row("成功处理文件数", str(stats["processed_files"]))
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
table.add_row("失败的文件数", str(len(stats["failed_files"])))
self.console.print(table)
if stats["failed_files"]:
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
for filename, error in stats["failed_files"]:
self.console.print(f"[red]- {filename}: {error}[/red]")
if stats["skipped_files"]:
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
def calculate_file_hash(self, file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
console = Console()
console.print("[bold green]知识库处理工具[/bold green]")
while True:
console.print("\n请选择要执行的操作:")
console.print("[1] 麦麦开始学习")
console.print("[2] 麦麦全部忘光光(仅知识)")
console.print("[q] 退出程序")
choice = input("\n请输入选项: ").strip()
if choice.lower() == 'q':
console.print("[yellow]程序退出[/yellow]")
sys.exit(0)
elif choice == '2':
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == 'y':
knowledge_library.db.db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == '1':
if not os.path.exists(knowledge_library.raw_info_dir):
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
# 询问分割长度
while True:
try:
length_input = input("请输入知识分割长度默认512输入q退出回车使用默认值: ").strip()
if length_input.lower() == 'q':
break
if not length_input: # 如果直接回车,使用默认值
knowledge_length = 512
break
knowledge_length = int(length_input)
if knowledge_length <= 0:
print("分割长度必须大于0请重新输入")
continue
break
except ValueError:
print("请输入有效的数字")
continue
if length_input.lower() == 'q':
continue
# 测试知识库功能
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
knowledge_library.process_files(knowledge_length=knowledge_length)
else:
console.print("[red]无效的选项,请重新选择[/red]")
continue

View File

@ -5,13 +5,18 @@ PORT=8080
PLUGINS=["src2.plugins.chat"]
# 默认配置
MONGODB_HOST=127.0.0.1 # 如果工作在Docker下请改成 MONGODB_HOST=mongodb
# 如果工作在Docker下请改成 MONGODB_HOST=mongodb
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 默认空值
MONGODB_PASSWORD = "" # 默认空值
MONGODB_AUTH_SOURCE = "" # 默认空值
# 也可以使用 URI 连接数据库(优先级比上面的高)
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# MongoDB 认证信息,若需要认证,请取消注释以下三行并填写正确的信息
# MONGODB_USERNAME=user
# MONGODB_PASSWORD=password
# MONGODB_AUTH_SOURCE=admin
#key and url
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1

View File

@ -1,5 +1,5 @@
[inner]
version = "0.0.4"
version = "0.0.8"
#如果你想要修改配置文件请在修改后将version的值进行变更
#如果新增项目请在BotConfig类下新增相应的变量
@ -15,6 +15,7 @@ version = "0.0.4"
[bot]
qq = 123
nickname = "麦麦"
alias_names = ["小麦", "阿麦"]
[personality]
prompt_personality = [
@ -40,6 +41,13 @@ ban_words = [
# "403","张三"
]
ban_msgs_regex = [
# 需要过滤的消息原始消息匹配的正则表达式匹配到的消息将被过滤支持CQ码若不了解正则表达式请勿修改
#"https?://[^\\s]+", # 匹配https链接
#"\\d{4}-\\d{2}-\\d{2}", # 匹配日期
# "\\[CQ:at,qq=\\d+\\]" # 匹配@
]
[emoji]
check_interval = 120 # 检查表情包的时间间隔
register_interval = 10 # 注册表情包的时间间隔
@ -57,8 +65,13 @@ model_r1_distill_probability = 0.1 # 麦麦回答时选择次要回复模型3
max_response_length = 1024 # 麦麦回答的最大token数
[memory]
build_memory_interval = 300 # 记忆构建间隔 单位秒
forget_memory_interval = 300 # 记忆遗忘间隔 单位秒
build_memory_interval = 600 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 600 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时
memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
memory_ban_words = [ #不希望记忆的词
# "403","张三"
@ -92,6 +105,8 @@ word_replace_rate=0.006 # 整词替换概率
[others]
enable_advance_output = true # 是否启用高级输出
enable_kuuki_read = true # 是否启用读空气功能
enable_debug_output = false # 是否启用调试输出
enable_friend_chat = false # 是否启用好友聊天
[groups]
talk_allowed = [

View File

@ -0,0 +1,4 @@
更新版本后建议删除数据库messages中所有内容不然会出现报错
该操作不会影响你的记忆
如果显示配置文件版本过低运行根目录的bat

View File

@ -0,0 +1,45 @@
@echo off
setlocal enabledelayedexpansion
chcp 65001
cd /d %~dp0
echo =====================================
echo 选择Python环境:
echo 1 - venv (推荐)
echo 2 - conda
echo =====================================
choice /c 12 /n /m "输入数字(1或2): "
if errorlevel 2 (
echo =====================================
set "CONDA_ENV="
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
:: 检查输入是否为空
if "!CONDA_ENV!"=="" (
echo 错误:环境名称不能为空
pause
exit /b 1
)
call conda activate !CONDA_ENV!
if errorlevel 1 (
echo 激活 conda 环境失败
pause
exit /b 1
)
echo Conda 环境 "!CONDA_ENV!" 激活成功
python config/auto_update.py
) else (
if exist "venv\Scripts\python.exe" (
venv\Scripts\python config/auto_update.py
) else (
echo =====================================
echo 错误: venv环境不存在请先创建虚拟环境
pause
exit /b 1
)
)
endlocal
pause

View File

@ -0,0 +1,45 @@
@echo off
setlocal enabledelayedexpansion
chcp 65001
cd /d %~dp0
echo =====================================
echo 选择Python环境:
echo 1 - venv (推荐)
echo 2 - conda
echo =====================================
choice /c 12 /n /m "输入数字(1或2): "
if errorlevel 2 (
echo =====================================
set "CONDA_ENV="
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
:: 检查输入是否为空
if "!CONDA_ENV!"=="" (
echo 错误:环境名称不能为空
pause
exit /b 1
)
call conda activate !CONDA_ENV!
if errorlevel 1 (
echo 激活 conda 环境失败
pause
exit /b 1
)
echo Conda 环境 "!CONDA_ENV!" 激活成功
python src/plugins/zhishi/knowledge_library.py
) else (
if exist "venv\Scripts\python.exe" (
venv\Scripts\python src/plugins/zhishi/knowledge_library.py
) else (
echo =====================================
echo 错误: venv环境不存在请先创建虚拟环境
pause
exit /b 1
)
)
endlocal
pause