From 5136c617ceccf1823e6666bfceded0821cf4cbc1 Mon Sep 17 00:00:00 2001
From: A0000Xz <122650088+A0000Xz@users.noreply.github.com>
Date: Wed, 28 May 2025 17:59:45 +0800
Subject: [PATCH 01/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=95=B0?=
=?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E5=B7=A5=E5=85=B7=E7=9A=84?=
=?UTF-8?q?=E8=84=9A=E6=9C=AC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
scripts/mongodb_to_sqlite.py | 692 +++++++++++++++++++++++++++++++++++
1 file changed, 692 insertions(+)
create mode 100644 scripts/mongodb_to_sqlite.py
diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py
new file mode 100644
index 00000000..83e08679
--- /dev/null
+++ b/scripts/mongodb_to_sqlite.py
@@ -0,0 +1,692 @@
+import os
+import json
+import sys # 新增系统模块导入
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from typing import Dict, Any, List, Optional, Type
+from dataclasses import dataclass, field
+from datetime import datetime
+from pymongo import MongoClient
+from pymongo.errors import ConnectionFailure
+from peewee import Model, Field, IntegrityError
+
+from src.common.database.database import db
+from src.common.database.database_model import (
+ ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions,
+ OnlineTime, PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges
+)
+from src.common.logger_manager import get_logger
+
+logger = get_logger("mongodb_to_sqlite")
+
+
+@dataclass
+class MigrationConfig:
+ """迁移配置类"""
+ mongo_collection: str
+ target_model: Type[Model]
+ field_mapping: Dict[str, str]
+ batch_size: int = 500
+ enable_validation: bool = True
+ skip_duplicates: bool = True
+ unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
+
+
+@dataclass
+class MigrationStats:
+ """迁移统计信息"""
+ total_documents: int = 0
+ processed_count: int = 0
+ success_count: int = 0
+ error_count: int = 0
+ skipped_count: int = 0
+ duplicate_count: int = 0
+ errors: List[Dict[str, Any]] = field(default_factory=list)
+
+ def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
+ """添加错误记录"""
+ self.errors.append({
+ 'doc_id': str(doc_id),
+ 'error': error,
+ 'timestamp': datetime.now().isoformat(),
+ 'doc_data': doc_data
+ })
+ self.error_count += 1
+
+
+class MongoToSQLiteMigrator:
+ """MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
+
+ def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
+ self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
+ self.mongo_uri = mongo_uri or self._build_mongo_uri()
+ self.mongo_client: Optional[MongoClient] = None
+ self.mongo_db = None
+
+ # 迁移配置
+ self.migration_configs = self._initialize_migration_configs()
+
+ def _build_mongo_uri(self) -> str:
+     """Build the MongoDB URI: ``MONGODB_URI`` if set, else assembled from ``MONGODB_*`` variables."""
+     from urllib.parse import quote_plus
+
+     if mongo_uri := os.getenv("MONGODB_URI"):
+         return mongo_uri
+     user = os.getenv('MONGODB_USER')
+     password = os.getenv('MONGODB_PASS')
+     host = os.getenv('MONGODB_HOST', 'localhost')
+     port = os.getenv('MONGODB_PORT', '27017')
+     auth_source = os.getenv('MONGODB_AUTH_SOURCE', 'admin')
+     if not (user and password):
+         return f"mongodb://{host}:{port}/{self.database_name}"
+     # Credentials must be percent-escaped, otherwise '@', ':' or '/' in them corrupt the URI.
+     return f"mongodb://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{self.database_name}?authSource={auth_source}"
+
+ def _initialize_migration_configs(self) -> List[MigrationConfig]:
+ """初始化迁移配置"""
+ return [
+ # 表情包迁移配置
+ MigrationConfig(
+ mongo_collection="emoji",
+ target_model=Emoji,
+ field_mapping={
+ "full_path": "full_path",
+ "format": "format",
+ "hash": "emoji_hash",
+ "description": "description",
+ "emotion": "emotion",
+ "usage_count": "usage_count",
+ "last_used_time": "last_used_time",
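+ "last_used_time": "record_time" # NOTE(review): duplicate dict key - this entry overrides the "last_used_time" -> "last_used_time" mapping on the previous line, so last_used_time itself is never migrated. The entry exists only to fit the mapping format; record_time is filled with the current time when it is a DateTimeField.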
+ },
+ unique_fields=["full_path", "emoji_hash"]
+ ),
+
+ # 聊天流迁移配置
+ MigrationConfig(
+ mongo_collection="chat_streams",
+ target_model=ChatStreams,
+ field_mapping={
+ "stream_id": "stream_id",
+ "create_time": "create_time",
+ "group_info.platform": "group_platform",# 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。
+ "group_info.group_id": "group_id", # 同上
+ "group_info.group_name": "group_name", # 同上
+ "last_active_time": "last_active_time",
+ "platform": "platform",
+ "user_info.platform": "user_platform",
+ "user_info.user_id": "user_id",
+ "user_info.user_nickname": "user_nickname",
+ "user_info.user_cardname": "user_cardname"
+ },
+ unique_fields=["stream_id"]
+ ),
+
+ # LLM使用记录迁移配置
+ MigrationConfig(
+ mongo_collection="llm_usage",
+ target_model=LLMUsage,
+ field_mapping={
+ "model_name": "model_name",
+ "user_id": "user_id",
+ "request_type": "request_type",
+ "endpoint": "endpoint",
+ "prompt_tokens": "prompt_tokens",
+ "completion_tokens": "completion_tokens",
+ "total_tokens": "total_tokens",
+ "cost": "cost",
+ "status": "status",
+ "timestamp": "timestamp"
+ },
+ unique_fields=["user_id", "timestamp"] # 组合唯一性
+ ),
+
+ # 消息迁移配置
+ MigrationConfig(
+ mongo_collection="messages",
+ target_model=Messages,
+ field_mapping={
+ "message_id": "message_id",
+ "time": "time",
+ "chat_id": "chat_id",
+ "chat_info.stream_id": "chat_info_stream_id",
+ "chat_info.platform": "chat_info_platform",
+ "chat_info.user_info.platform": "chat_info_user_platform",
+ "chat_info.user_info.user_id": "chat_info_user_id",
+ "chat_info.user_info.user_nickname": "chat_info_user_nickname",
+ "chat_info.user_info.user_cardname": "chat_info_user_cardname",
+ "chat_info.group_info.platform": "chat_info_group_platform",
+ "chat_info.group_info.group_id": "chat_info_group_id",
+ "chat_info.group_info.group_name": "chat_info_group_name",
+ "chat_info.create_time": "chat_info_create_time",
+ "chat_info.last_active_time": "chat_info_last_active_time",
+ "user_info.platform": "user_platform",
+ "user_info.user_id": "user_id",
+ "user_info.user_nickname": "user_nickname",
+ "user_info.user_cardname": "user_cardname",
+ "processed_plain_text": "processed_plain_text",
+ "detailed_plain_text": "detailed_plain_text",
+ "memorized_times": "memorized_times"
+ },
+ unique_fields=["message_id"]
+ ),
+
+ # 图片迁移配置
+ MigrationConfig(
+ mongo_collection="images",
+ target_model=Images,
+ field_mapping={
+ "hash": "emoji_hash",
+ "description": "description",
+ "path": "path",
+ "timestamp": "timestamp",
+ "type": "type"
+ },
+ unique_fields=["path"]
+ ),
+
+ # 图片描述迁移配置
+ MigrationConfig(
+ mongo_collection="image_descriptions",
+ target_model=ImageDescriptions,
+ field_mapping={
+ "type": "type",
+ "hash": "image_description_hash",
+ "description": "description",
+ "timestamp": "timestamp"
+ },
+ unique_fields=["image_description_hash", "type"]
+ ),
+
+ # 在线时长迁移配置
+ MigrationConfig(
+ mongo_collection="online_time",
+ target_model=OnlineTime,
+ field_mapping={
+ "timestamp": "timestamp",
+ "duration": "duration",
+ "start_timestamp": "start_timestamp",
+ "end_timestamp": "end_timestamp"
+ },
+ unique_fields=["start_timestamp", "end_timestamp"]
+ ),
+
+ # 个人信息迁移配置
+ MigrationConfig(
+ mongo_collection="person_info",
+ target_model=PersonInfo,
+ field_mapping={
+ "person_id": "person_id",
+ "person_name": "person_name",
+ "name_reason": "name_reason",
+ "platform": "platform",
+ "user_id": "user_id",
+ "nickname": "nickname",
+ "relationship_value": "relationship_value",
+ "konw_time": "know_time",
+ "msg_interval": "msg_interval",
+ "msg_interval_list": "msg_interval_list"
+ },
+ unique_fields=["person_id"]
+ ),
+
+ # 知识库迁移配置
+ MigrationConfig(
+ mongo_collection="knowledges",
+ target_model=Knowledges,
+ field_mapping={
+ "content": "content",
+ "embedding": "embedding"
+ },
+ unique_fields=["content"] # 假设内容唯一
+ ),
+
+ # 思考日志迁移配置
+ MigrationConfig(
+ mongo_collection="thinking_log",
+ target_model=ThinkingLog,
+ field_mapping={
+ "chat_id": "chat_id",
+ "trigger_text": "trigger_text",
+ "response_text": "response_text",
+ "trigger_info": "trigger_info_json",
+ "response_info": "response_info_json",
+ "timing_results": "timing_results_json",
+ "chat_history": "chat_history_json",
+ "chat_history_in_thinking": "chat_history_in_thinking_json",
+ "chat_history_after_response": "chat_history_after_response_json",
+ "heartflow_data": "heartflow_data_json",
+ "reasoning_data": "reasoning_data_json",
+ },
+ unique_fields=["chat_id", "created_at"]
+ ),
+
+ # 图节点迁移配置
+ MigrationConfig(
+ mongo_collection="graph_data.nodes",
+ target_model=GraphNodes,
+ field_mapping={
+ "concept": "concept",
+ "memory_items": "memory_items",
+ "hash": "hash",
+ "created_time": "created_time",
+ "last_modified": "last_modified"
+ },
+ unique_fields=["concept"]
+ ),
+
+ # 图边迁移配置
+ MigrationConfig(
+ mongo_collection="graph_data.edges",
+ target_model=GraphEdges,
+ field_mapping={
+ "source": "source",
+ "target": "target",
+ "strength": "strength",
+ "hash": "hash",
+ "created_time": "created_time",
+ "last_modified": "last_modified"
+ },
+ unique_fields=["source", "target"] # 组合唯一性
+ )
+ ]
+
+ def connect_mongodb(self) -> bool:
+ """连接到MongoDB"""
+ try:
+ self.mongo_client = MongoClient(
+ self.mongo_uri,
+ serverSelectionTimeoutMS=5000,
+ connectTimeoutMS=10000,
+ maxPoolSize=10
+ )
+
+ # 测试连接
+ self.mongo_client.admin.command('ping')
+ self.mongo_db = self.mongo_client[self.database_name]
+
+ logger.info(f"成功连接到MongoDB: {self.database_name}")
+ return True
+
+ except ConnectionFailure as e:
+ logger.error(f"MongoDB连接失败: {e}")
+ return False
+ except Exception as e:
+ logger.error(f"MongoDB连接异常: {e}")
+ return False
+
+ def disconnect_mongodb(self):
+ """断开MongoDB连接"""
+ if self.mongo_client:
+ self.mongo_client.close()
+ logger.info("MongoDB连接已关闭")
+
+ def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
+ """获取嵌套字段的值"""
+ if "." not in field_path:
+ return document.get(field_path)
+
+ parts = field_path.split(".")
+ value = document
+
+ for part in parts:
+ if isinstance(value, dict):
+ value = value.get(part)
+ else:
+ return None
+
+ if value is None:
+ break
+
+ return value
+
+ def _convert_field_value(self, value: Any, target_field: Field) -> Any:
+ """根据目标字段类型转换值"""
+ if value is None:
+ return None
+
+ field_type = target_field.__class__.__name__
+
+ try:
+ if target_field.name == "record_time" and field_type == "DateTimeField":
+ return datetime.now()
+
+ if target_field.name == "record_time" and field_type == "DateTimeField":
+ return self._convert_record_time(value)
+
+ if field_type in ["CharField", "TextField"]:
+ if isinstance(value, (list, dict)):
+ return json.dumps(value, ensure_ascii=False)
+ return str(value) if value is not None else ""
+
+ elif field_type == "IntegerField":
+ if isinstance(value, str):
+ # 处理字符串数字
+ clean_value = value.strip()
+ if clean_value.replace('.', '').replace('-', '').isdigit():
+ return int(float(clean_value))
+ return 0
+ return int(value) if value is not None else 0
+
+ elif field_type in ["FloatField", "DoubleField"]:
+ return float(value) if value is not None else 0.0
+
+ elif field_type == "BooleanField":
+ if isinstance(value, str):
+ return value.lower() in ('true', '1', 'yes', 'on')
+ return bool(value)
+
+ elif field_type == "DateTimeField":
+ if isinstance(value, (int, float)):
+ return datetime.fromtimestamp(value)
+ elif isinstance(value, str):
+ try:
+ # 尝试解析ISO格式日期
+ return datetime.fromisoformat(value.replace('Z', '+00:00'))
+ except ValueError:
+ try:
+ # 尝试解析时间戳字符串
+ return datetime.fromtimestamp(float(value))
+ except ValueError:
+ return datetime.now()
+ return datetime.now()
+
+ return value
+
+ except (ValueError, TypeError) as e:
+ logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
+ return self._get_default_value_for_field(target_field)
+
+ def _get_default_value_for_field(self, field: Field) -> Any:
+ """获取字段的默认值"""
+ field_type = field.__class__.__name__
+
+ if hasattr(field, 'default') and field.default is not None:
+ return field.default
+
+ if field.null:
+ return None
+
+ # 根据字段类型返回默认值
+ if field_type in ["CharField", "TextField"]:
+ return ""
+ elif field_type == "IntegerField":
+ return 0
+ elif field_type in ["FloatField", "DoubleField"]:
+ return 0.0
+ elif field_type == "BooleanField":
+ return False
+ elif field_type == "DateTimeField":
+ return datetime.now()
+
+ return None
+
+ def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any],
+                                       unique_fields: List[str]) -> bool:
+     """Return True when a row matching every populated unique field already exists.
+
+     Unique fields that are missing from ``data`` or None are not used as filters.
+     When no filter remains the result is False: an unfiltered ``select().exists()``
+     would report every document as a duplicate as soon as the table holds one row.
+     """
+     try:
+         present = [name for name in unique_fields if data.get(name) is not None]
+         if not present:
+             return False
+         filters = [getattr(model, name) == data[name] for name in present]
+         return model.select().where(*filters).exists()
+     except Exception as e:
+         logger.debug(f"重复检查失败: {e}")
+         return False
+
+ def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
+ """使用ORM创建模型实例"""
+ try:
+ # 过滤掉不存在的字段
+ valid_data = {}
+ for field_name, value in data.items():
+ if hasattr(model, field_name):
+ valid_data[field_name] = value
+ else:
+ logger.debug(f"跳过未知字段: {field_name}")
+
+ # 创建实例
+ instance = model.create(**valid_data)
+ return instance
+
+ except IntegrityError as e:
+ # 处理唯一约束冲突等完整性错误
+ logger.debug(f"完整性约束冲突: {e}")
+ return None
+ except Exception as e:
+ logger.error(f"创建模型实例失败: {e}")
+ return None
+
+ def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
+ """迁移单个集合 - 使用ORM方式"""
+ stats = MigrationStats()
+
+ logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
+
+ try:
+ # 获取MongoDB集合
+ mongo_collection = self.mongo_db[config.mongo_collection]
+ stats.total_documents = mongo_collection.count_documents({})
+
+ if stats.total_documents == 0:
+ logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
+ return stats
+
+ logger.info(f"待迁移文档数量: {stats.total_documents}")
+
+ # 逐个处理文档
+ batch_count = 0
+ for mongo_doc in mongo_collection.find().batch_size(config.batch_size):
+ try:
+ stats.processed_count += 1
+ doc_id = mongo_doc.get('_id', 'unknown')
+
+ # 构建目标数据
+ target_data = {}
+ for mongo_field, sqlite_field in config.field_mapping.items():
+ value = self._get_nested_value(mongo_doc, mongo_field)
+
+ # 获取目标字段对象并转换类型
+ if hasattr(config.target_model, sqlite_field):
+ field_obj = getattr(config.target_model, sqlite_field)
+ converted_value = self._convert_field_value(value, field_obj)
+ target_data[sqlite_field] = converted_value
+
+ # 重复检查
+ if config.skip_duplicates and self._check_duplicate_by_unique_fields(
+ config.target_model, target_data, config.unique_fields
+ ):
+ stats.duplicate_count += 1
+ stats.skipped_count += 1
+ logger.debug(f"跳过重复记录: {doc_id}")
+ continue
+
+ # 使用ORM创建实例
+ with db.atomic(): # 每个实例的事务保护
+ instance = self._create_model_instance(config.target_model, target_data)
+
+ if instance:
+ stats.success_count += 1
+ else:
+ stats.add_error(doc_id, "ORM创建实例失败", target_data)
+
+ except Exception as e:
+ doc_id = mongo_doc.get('_id', 'unknown')
+ stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
+ logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
+
+ # 进度报告
+ batch_count += 1
+ if batch_count % config.batch_size == 0:
+ progress = (stats.processed_count / stats.total_documents) * 100
+ logger.info(
+ f"迁移进度: {stats.processed_count}/{stats.total_documents} "
+ f"({progress:.1f}%) - 成功: {stats.success_count}, "
+ f"错误: {stats.error_count}, 跳过: {stats.skipped_count}"
+ )
+
+ logger.info(
+ f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
+ f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
+ f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}"
+ )
+
+ except Exception as e:
+ logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
+ stats.add_error("collection_error", str(e))
+
+ return stats
+
+ def migrate_all(self) -> Dict[str, MigrationStats]:
+ """执行所有迁移任务"""
+ logger.info("开始执行数据库迁移...")
+
+ if not self.connect_mongodb():
+ logger.error("无法连接到MongoDB,迁移终止")
+ return {}
+
+ all_stats = {}
+
+ try:
+ for config in self.migration_configs:
+ logger.info(f"\n开始处理集合: {config.mongo_collection}")
+ stats = self.migrate_collection(config)
+ all_stats[config.mongo_collection] = stats
+
+ # 错误率检查
+ if stats.processed_count > 0:
+ error_rate = stats.error_count / stats.processed_count
+ if error_rate > 0.1: # 错误率超过10%
+ logger.warning(
+ f"集合 {config.mongo_collection} 错误率较高: {error_rate:.1%} "
+ f"({stats.error_count}/{stats.processed_count})"
+ )
+
+ finally:
+ self.disconnect_mongodb()
+
+ self._print_migration_summary(all_stats)
+ return all_stats
+
+ def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
+ """打印迁移汇总信息"""
+ logger.info("\n" + "="*60)
+ logger.info("数据迁移汇总报告")
+ logger.info("="*60)
+
+ total_processed = sum(stats.processed_count for stats in all_stats.values())
+ total_success = sum(stats.success_count for stats in all_stats.values())
+ total_errors = sum(stats.error_count for stats in all_stats.values())
+ total_skipped = sum(stats.skipped_count for stats in all_stats.values())
+ total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
+
+ # 表头
+ logger.info(f"{'集合名称':<20} | {'处理':<6} | {'成功':<6} | {'错误':<6} | {'跳过':<6} | {'重复':<6} | {'成功率':<8}")
+ logger.info("-" * 75)
+
+ for collection_name, stats in all_stats.items():
+ success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
+ logger.info(
+ f"{collection_name:<20} | "
+ f"{stats.processed_count:<6} | "
+ f"{stats.success_count:<6} | "
+ f"{stats.error_count:<6} | "
+ f"{stats.skipped_count:<6} | "
+ f"{stats.duplicate_count:<6} | "
+ f"{success_rate:<7.1f}%"
+ )
+
+ logger.info("-" * 75)
+ total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
+ logger.info(
+ f"{'总计':<20} | "
+ f"{total_processed:<6} | "
+ f"{total_success:<6} | "
+ f"{total_errors:<6} | "
+ f"{total_skipped:<6} | "
+ f"{total_duplicates:<6} | "
+ f"{total_success_rate:<7.1f}%"
+ )
+
+ if total_errors > 0:
+ logger.warning(f"\n⚠️ 存在 {total_errors} 个错误,请检查日志详情")
+
+ if total_duplicates > 0:
+ logger.info(f"ℹ️ 跳过了 {total_duplicates} 个重复记录")
+
+ logger.info("="*60)
+
+ def add_migration_config(self, config: MigrationConfig):
+ """添加新的迁移配置"""
+ self.migration_configs.append(config)
+
+ def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
+ """迁移单个指定的集合"""
+ config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
+ if not config:
+ logger.error(f"未找到集合 {collection_name} 的迁移配置")
+ return None
+
+ if not self.connect_mongodb():
+ logger.error("无法连接到MongoDB")
+ return None
+
+ try:
+ stats = self.migrate_collection(config)
+ self._print_migration_summary({collection_name: stats})
+ return stats
+ finally:
+ self.disconnect_mongodb()
+
+ def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
+     """Write per-collection counters and the collected error records to ``filepath`` as JSON."""
+     error_report = {
+         'timestamp': datetime.now().isoformat(),
+         'summary': {
+             collection: {
+                 'total': stats.total_documents,
+                 'processed': stats.processed_count,
+                 'success': stats.success_count,
+                 'errors': stats.error_count,
+                 'skipped': stats.skipped_count,
+                 'duplicates': stats.duplicate_count
+             }
+             for collection, stats in all_stats.items()
+         },
+         'errors': {
+             collection: stats.errors
+             for collection, stats in all_stats.items()
+             if stats.errors
+         }
+     }
+     # default=str: doc_data carries raw documents / converted rows (e.g. `_id`, datetime) json cannot encode.
+     try:
+         with open(filepath, 'w', encoding='utf-8') as f:
+             json.dump(error_report, f, ensure_ascii=False, indent=2, default=str)
+         logger.info(f"错误报告已导出到: {filepath}")
+     except Exception as e:
+         logger.error(f"导出错误报告失败: {e}")
+
+
+def main():
+    """Run the full migration; return 0 on completion, 1 when migrate_all() reports no collections."""
+    migrator = MongoToSQLiteMigrator()
+    migration_results = migrator.migrate_all()
+    if not migration_results:  # migrate_all() returns {} when MongoDB is unreachable
+        return 1
+    # Export an error report only when at least one collection recorded errors.
+    if any(stats.error_count > 0 for stats in migration_results.values()):
+        error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        migrator.export_error_report(migration_results, error_report_path)
+
+    logger.info("数据迁移完成!")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file
From bd58fa48ce000f2c6fbc4aad60e30835a2c737a6 Mon Sep 17 00:00:00 2001
From: A0000Xz <122650088+A0000Xz@users.noreply.github.com>
Date: Wed, 28 May 2025 18:00:22 +0800
Subject: [PATCH 02/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=95=B0?=
=?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E5=B7=A5=E5=85=B7=E7=9A=84?=
=?UTF-8?q?=E5=90=AF=E5=8A=A8=E8=84=9A=E6=9C=AC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mongodb_to_sqlite .bat | 72 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 72 insertions(+)
create mode 100644 mongodb_to_sqlite .bat
diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat
new file mode 100644
index 00000000..bc4c5855
--- /dev/null
+++ b/mongodb_to_sqlite .bat
@@ -0,0 +1,72 @@
+@echo off
+CHCP 65001 > nul
+setlocal enabledelayedexpansion
+
+echo 你需要选择启动方式,输入字母来选择:
+echo V = 不知道什么意思就输入 V
+echo C = 输入 C 使用 Conda 环境
+echo.
+choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V
+
+set "ENV_TYPE="
+if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA"
+if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV"
+
+if "%ENV_TYPE%" == "CONDA" goto activate_conda
+if "%ENV_TYPE%" == "VENV" goto activate_venv
+
+REM 如果 choice 超时或返回意外值,默认使用 venv
+echo WARN: Invalid selection or timeout from choice. Defaulting to VENV.
+set "ENV_TYPE=VENV"
+goto activate_venv
+
+:activate_conda
+ set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: "
+ if not defined CONDA_ENV_NAME (
+ echo 错误: 未输入 Conda 环境名称.
+ pause
+ exit /b 1
+ )
+ echo 选择: Conda '!CONDA_ENV_NAME!'
+ REM 激活Conda环境
+ call conda activate !CONDA_ENV_NAME!
+ if !ERRORLEVEL! neq 0 (
+ echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在.
+ pause
+ exit /b 1
+ )
+ goto env_activated
+
+:activate_venv
+ echo Selected: venv (default or selected)
+ REM 查找venv虚拟环境
+ set "venv_path=%~dp0venv\Scripts\activate.bat"
+ if not exist "%venv_path%" (
+ echo Error: venv not found. Ensure the venv directory exists alongside the script.
+ pause
+ exit /b 1
+ )
+ REM 激活虚拟环境
+ call "%venv_path%"
+ if %ERRORLEVEL% neq 0 (
+ echo Error: Failed to activate venv virtual environment.
+ pause
+ exit /b 1
+ )
+ goto env_activated
+
+:env_activated
+echo Environment activated successfully!
+
+REM --- 后续脚本执行 ---
+
+REM 运行预处理脚本
+python "%~dp0scripts\mongodb_to_sqlite.py"
+if %ERRORLEVEL% neq 0 (
+ echo Error: raw_data_preprocessor.py execution failed.
+ pause
+ exit /b 1
+)
+
+echo All processing steps completed!
+pause
\ No newline at end of file
From 4ca37440720ad2b9100555d4ae118b4421224f03 Mon Sep 17 00:00:00 2001
From: A0000Xz <122650088+A0000Xz@users.noreply.github.com>
Date: Wed, 28 May 2025 18:31:59 +0800
Subject: [PATCH 03/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=BB=86=E8=8A=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mongodb_to_sqlite .bat | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat
index bc4c5855..5c029e44 100644
--- a/mongodb_to_sqlite .bat
+++ b/mongodb_to_sqlite .bat
@@ -63,7 +63,7 @@ REM --- 后续脚本执行 ---
REM 运行预处理脚本
python "%~dp0scripts\mongodb_to_sqlite.py"
if %ERRORLEVEL% neq 0 (
- echo Error: raw_data_preprocessor.py execution failed.
+ echo Error: mongodb_to_sqlite.py execution failed.
pause
exit /b 1
)
From d84b030fa63013d95c210ad3e820a2ed39159a93 Mon Sep 17 00:00:00 2001
From: A0000Xz <122650088+A0000Xz@users.noreply.github.com>
Date: Wed, 28 May 2025 18:32:38 +0800
Subject: [PATCH 04/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=BA=9B?=
=?UTF-8?q?=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
scripts/mongodb_to_sqlite.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py
index 83e08679..5ff346af 100644
--- a/scripts/mongodb_to_sqlite.py
+++ b/scripts/mongodb_to_sqlite.py
@@ -258,7 +258,7 @@ class MongoToSQLiteMigrator:
"heartflow_data": "heartflow_data_json",
"reasoning_data": "reasoning_data_json",
},
- unique_fields=["chat_id", "created_at"]
+ unique_fields=["chat_id", "trigger_text"]
),
# 图节点迁移配置
@@ -350,9 +350,6 @@ class MongoToSQLiteMigrator:
try:
if target_field.name == "record_time" and field_type == "DateTimeField":
return datetime.now()
-
- if target_field.name == "record_time" and field_type == "DateTimeField":
- return self._convert_record_time(value)
if field_type in ["CharField", "TextField"]:
if isinstance(value, (list, dict)):
From bc489861d30703fc4062494ef327226c57452485 Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Wed, 28 May 2025 20:41:46 +0800
Subject: [PATCH 05/13] =?UTF-8?q?feat=EF=BC=9A=E4=B8=BAnormal=5Fchat?=
=?UTF-8?q?=E6=8F=90=E4=BE=9B=E9=80=89=E9=A1=B9=EF=BC=8C=E6=9C=89=E6=95=88?=
=?UTF-8?q?=E6=8E=A7=E5=88=B6=E5=9B=9E=E5=A4=8D=E9=A2=91=E7=8E=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
changelogs/changelog.md | 9 +-
.../focus_chat/heartflow_message_processor.py | 18 ++-
src/chat/heart_flow/background_tasks.py | 74 +---------
src/chat/heart_flow/heartflow.py | 8 +-
src/chat/heart_flow/mai_state_manager.py | 135 ------------------
src/chat/heart_flow/sub_heartflow.py | 9 +-
src/chat/heart_flow/subheartflow_manager.py | 5 +-
src/chat/normal_chat/normal_chat.py | 66 +++++----
src/chat/normal_chat/normal_chat_utils.py | 33 +++++
src/common/logger.py | 4 +-
src/config/official_configs.py | 5 +-
template/bot_config_template.toml | 56 ++++----
12 files changed, 138 insertions(+), 284 deletions(-)
delete mode 100644 src/chat/heart_flow/mai_state_manager.py
create mode 100644 src/chat/normal_chat/normal_chat_utils.py
diff --git a/changelogs/changelog.md b/changelogs/changelog.md
index 55d87d40..c4dc8b3b 100644
--- a/changelogs/changelog.md
+++ b/changelogs/changelog.md
@@ -8,8 +8,8 @@
- 重构表情包模块
- 移除日程系统
-**重构专注聊天(HFC)**
-- 模块化HFC,可以自定义不同的部件
+**重构专注聊天(HFC - focus_chat)**
+- 模块化设计,可以自定义不同的部件
- 观察器(获取信息)
- 信息处理器(处理信息)
- 重构:聊天思考(子心流)处理器
@@ -31,6 +31,10 @@
- 在专注模式下,麦麦可以决定自行发送语音消息(需要搭配tts适配器)
- 优化reply,减少复读
+**优化普通聊天(normal_chat)**
+- 增加了talk_frequency参数来有效控制回复频率
+- 优化了进入和离开normal_chat的方式
+
**新增表达方式学习**
- 在专注模式下,麦麦可以有独特的表达方式
- 自主学习群聊中的表达方式,更贴近群友
@@ -50,6 +54,7 @@
**人格**
- 简化了人格身份的配置
+- 优化了在focus模式下人格的表现和稳定性
**数据库重构**
- 移除了默认使用MongoDB,采用轻量sqlite
diff --git a/src/chat/focus_chat/heartflow_message_processor.py b/src/chat/focus_chat/heartflow_message_processor.py
index c1efeb52..18349d7a 100644
--- a/src/chat/focus_chat/heartflow_message_processor.py
+++ b/src/chat/focus_chat/heartflow_message_processor.py
@@ -8,6 +8,7 @@ from ..utils.utils import is_mentioned_bot_in_message
from src.chat.heart_flow.heartflow import heartflow
from src.common.logger_manager import get_logger
from ..message_receive.chat_stream import chat_manager
+import math
# from ..message_receive.message_buffer import message_buffer
from ..utils.timer_calculator import Timer
@@ -69,6 +70,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
message.processed_plain_text,
fast_retrieval=True,
)
+ text_len = len(message.processed_plain_text)
+ # 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
+ # 采用对数函数实现递减增长
+
+ base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
+ base_interest = min(max(base_interest, 0.01), 0.05)
+
+ interested_rate += base_interest
+
logger.trace(f"记忆激活率: {interested_rate:.2f}")
if is_mentioned:
@@ -205,17 +215,15 @@ class HeartFCMessageReceiver:
# 6. 兴趣度计算与更新
interested_rate, is_mentioned = await _calculate_interest(message)
- # await subheartflow.interest_chatting.increase_interest(value=interested_rate)
- subheartflow.add_interest_message(message, interested_rate, is_mentioned)
+ subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
# 7. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
- current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
+ # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
logger.info(
- f"[{current_time}][{mes_name}]"
+ f"[{mes_name}]"
f"{userinfo.user_nickname}:"
f"{message.processed_plain_text}"
- f"[激活: {interested_rate:.1f}]"
)
# 8. 关系处理
diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py
index 9479804e..db260456 100644
--- a/src/chat/heart_flow/background_tasks.py
+++ b/src/chat/heart_flow/background_tasks.py
@@ -2,25 +2,18 @@ import asyncio
import traceback
from typing import Optional, Coroutine, Callable, Any, List
from src.common.logger_manager import get_logger
-from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.config.config import global_config
logger = get_logger("background_tasks")
-# 新增兴趣评估间隔
-INTEREST_EVAL_INTERVAL_SECONDS = 5
-# 新增聊天超时检查间隔
-NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS = 60
-# 新增状态评估间隔
-HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS = 20
+
+
# 新增私聊激活检查间隔
PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS = 5 # 与兴趣评估类似,设为5秒
CLEANUP_INTERVAL_SECONDS = 1200
-STATE_UPDATE_INTERVAL_SECONDS = 60
-LOG_INTERVAL_SECONDS = 3
async def _run_periodic_loop(
@@ -55,19 +48,13 @@ class BackgroundTaskManager:
def __init__(
self,
- mai_state_info: MaiStateInfo, # Needs current state info
- mai_state_manager: MaiStateManager,
subheartflow_manager: SubHeartflowManager,
):
- self.mai_state_info = mai_state_info
- self.mai_state_manager = mai_state_manager
self.subheartflow_manager = subheartflow_manager
# Task references
- self._state_update_task: Optional[asyncio.Task] = None
self._cleanup_task: Optional[asyncio.Task] = None
self._hf_judge_state_update_task: Optional[asyncio.Task] = None
- self._into_focus_task: Optional[asyncio.Task] = None
self._private_chat_activation_task: Optional[asyncio.Task] = None # 新增私聊激活任务引用
self._tasks: List[Optional[asyncio.Task]] = [] # Keep track of all tasks
@@ -80,15 +67,7 @@ class BackgroundTaskManager:
- 将任务引用保存到任务列表
"""
- # 任务配置列表: (任务函数, 任务名称, 日志级别, 额外日志信息, 任务对象引用属性名)
- task_configs = [
- (
- lambda: self._run_state_update_cycle(STATE_UPDATE_INTERVAL_SECONDS),
- "debug",
- f"聊天状态更新任务已启动 间隔:{STATE_UPDATE_INTERVAL_SECONDS}s",
- "_state_update_task",
- ),
- ]
+ task_configs = []
# 根据 chat_mode 条件添加其他任务
if not (global_config.chat.chat_mode == "normal"):
@@ -108,12 +87,6 @@ class BackgroundTaskManager:
f"私聊激活检查任务已启动 间隔:{PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS}s",
"_private_chat_activation_task",
),
- # (
- # self._run_into_focus_cycle,
- # "debug", # 设为debug,避免过多日志
- # f"专注评估任务已启动 间隔:{INTEREST_EVAL_INTERVAL_SECONDS}s",
- # "_into_focus_task",
- # )
]
)
else:
@@ -163,33 +136,9 @@ class BackgroundTaskManager:
# 第三步:清空任务列表
self._tasks = [] # 重置任务列表
- async def _perform_state_update_work(self):
- """执行状态更新工作"""
- previous_status = self.mai_state_info.get_current_state()
- next_state = self.mai_state_manager.check_and_decide_next_state(self.mai_state_info)
-
- state_changed = False
-
- if next_state is not None:
- state_changed = self.mai_state_info.update_mai_status(next_state)
-
- # 处理保持离线状态的特殊情况
- if not state_changed and next_state == previous_status == self.mai_state_info.mai_status.OFFLINE:
- self.mai_state_info.reset_state_timer()
- logger.debug("[后台任务] 保持离线状态并重置计时器")
- state_changed = True # 触发后续处理
-
- if state_changed:
- current_state = self.mai_state_info.get_current_state()
-
# 状态转换处理
- if (
- current_state == self.mai_state_info.mai_status.OFFLINE
- and previous_status != self.mai_state_info.mai_status.OFFLINE
- ):
- logger.info("检测到离线,停用所有子心流")
- await self.subheartflow_manager.deactivate_all_subflows()
+
async def _perform_cleanup_work(self):
"""执行子心流清理任务
@@ -216,27 +165,12 @@ class BackgroundTaskManager:
# 记录最终清理结果
logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流")
- # --- 新增兴趣评估工作函数 ---
- # async def _perform_into_focus_work(self):
- # """执行一轮子心流兴趣评估与提升检查。"""
- # # 直接调用 subheartflow_manager 的方法,并传递当前状态信息
- # await self.subheartflow_manager.sbhf_normal_into_focus()
-
- async def _run_state_update_cycle(self, interval: int):
- await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work)
async def _run_cleanup_cycle(self):
await _run_periodic_loop(
task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work
)
- # --- 新增兴趣评估任务运行器 ---
- # async def _run_into_focus_cycle(self):
- # await _run_periodic_loop(
- # task_name="Into Focus",
- # interval=INTEREST_EVAL_INTERVAL_SECONDS,
- # task_func=self._perform_into_focus_work,
- # )
# 新增私聊激活任务运行器
async def _run_private_chat_activation_cycle(self, interval: int):
diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py
index e1f8d957..7e12135a 100644
--- a/src/chat/heart_flow/heartflow.py
+++ b/src/chat/heart_flow/heartflow.py
@@ -1,7 +1,6 @@
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.common.logger_manager import get_logger
from typing import Any, Optional, List
-from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager
@@ -16,17 +15,12 @@ class Heartflow:
"""
def __init__(self):
- # 状态管理相关
- self.current_state: MaiStateInfo = MaiStateInfo() # 当前状态信息
- self.mai_state_manager: MaiStateManager = MaiStateManager() # 状态决策管理器
# 子心流管理 (在初始化时传入 current_state)
- self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
+ self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager()
# 后台任务管理器 (整合所有定时任务)
self.background_task_manager: BackgroundTaskManager = BackgroundTaskManager(
- mai_state_info=self.current_state,
- mai_state_manager=self.mai_state_manager,
subheartflow_manager=self.subheartflow_manager,
)
diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py
deleted file mode 100644
index 81f03dec..00000000
--- a/src/chat/heart_flow/mai_state_manager.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import enum
-import time
-import random
-from typing import List, Tuple, Optional
-from src.common.logger_manager import get_logger
-from src.manager.mood_manager import mood_manager
-
-logger = get_logger("mai_state")
-
-
-class MaiState(enum.Enum):
- """
- 聊天状态:
- OFFLINE: 不在线:回复概率极低,不会进行任何聊天
- NORMAL_CHAT: 正常看手机:回复概率较高,会进行一些普通聊天和少量的专注聊天
- FOCUSED_CHAT: 专注聊天:回复概率极高,会进行专注聊天和少量的普通聊天
- """
-
- OFFLINE = "不在线"
- NORMAL_CHAT = "正常看手机"
- FOCUSED_CHAT = "专心看手机"
-
-
-class MaiStateInfo:
- def __init__(self):
- self.mai_status: MaiState = MaiState.NORMAL_CHAT # 初始状态改为 NORMAL_CHAT
- self.mai_status_history: List[Tuple[MaiState, float]] = [] # 历史状态,包含 状态,时间戳
- self.last_status_change_time: float = time.time() # 状态最后改变时间
- self.last_min_check_time: float = time.time() # 上次1分钟规则检查时间
-
- # Mood management is now part of MaiStateInfo
- self.mood_manager = mood_manager # Use singleton instance
-
- def update_mai_status(self, new_status: MaiState) -> bool:
- """
- 更新聊天状态。
-
- Args:
- new_status: 新的 MaiState 状态。
-
- Returns:
- bool: 如果状态实际发生了改变则返回 True,否则返回 False。
- """
- if new_status != self.mai_status:
- self.mai_status = new_status
- current_time = time.time()
- self.last_status_change_time = current_time
- self.last_min_check_time = current_time # Reset 1-min check on any state change
- self.mai_status_history.append((new_status, current_time))
- logger.info(f"麦麦状态更新为: {self.mai_status.value}")
- return True
- else:
- return False
-
- def reset_state_timer(self):
- """
- 重置状态持续时间计时器和一分钟规则检查计时器。
- 通常在状态保持不变但需要重新开始计时的情况下调用(例如,保持 OFFLINE)。
- """
- current_time = time.time()
- self.last_status_change_time = current_time
- self.last_min_check_time = current_time # Also reset the 1-min check timer
- logger.debug("MaiStateInfo 状态计时器已重置。")
-
- def get_mood_prompt(self) -> str:
- """获取当前的心情提示词"""
- # Delegate to the internal mood manager
- return self.mood_manager.get_mood_prompt()
-
- def get_current_state(self) -> MaiState:
- """获取当前的 MaiState"""
- return self.mai_status
-
-
-class MaiStateManager:
- """管理 Mai 的整体状态转换逻辑"""
-
- def __init__(self):
- pass
-
- @staticmethod
- def check_and_decide_next_state(current_state_info: MaiStateInfo) -> Optional[MaiState]:
- """
- 根据当前状态和规则检查是否需要转换状态,并决定下一个状态。
- """
- current_time = time.time()
- current_status = current_state_info.mai_status
- time_in_current_status = current_time - current_state_info.last_status_change_time
- next_state: Optional[MaiState] = None
-
- def _resolve_offline(candidate_state: MaiState) -> MaiState:
- if candidate_state == MaiState.OFFLINE:
- return current_status
- return candidate_state
-
- if current_status == MaiState.OFFLINE:
- logger.info("当前[离线],没看手机,思考要不要上线看看......")
- elif current_status == MaiState.NORMAL_CHAT:
- logger.info("当前在[正常看手机]思考要不要继续聊下去......")
- elif current_status == MaiState.FOCUSED_CHAT:
- logger.info("当前在[专心看手机]思考要不要继续聊下去......")
-
- if next_state is None:
- time_limit_exceeded = False
- choices_list = []
- weights = []
- rule_id = ""
-
- if current_status == MaiState.OFFLINE:
- return None
- elif current_status == MaiState.NORMAL_CHAT:
- if time_in_current_status >= 300: # NORMAL_CHAT 最多持续 300 秒
- time_limit_exceeded = True
- rule_id = "2.3 (From NORMAL_CHAT)"
- weights = [100]
- choices_list = [MaiState.FOCUSED_CHAT]
- elif current_status == MaiState.FOCUSED_CHAT:
- if time_in_current_status >= 600: # FOCUSED_CHAT 最多持续 600 秒
- time_limit_exceeded = True
- rule_id = "2.4 (From FOCUSED_CHAT)"
- weights = [100]
- choices_list = [MaiState.NORMAL_CHAT]
-
- if time_limit_exceeded:
- next_state_candidate = random.choices(choices_list, weights=weights, k=1)[0]
- resolved_candidate = _resolve_offline(next_state_candidate)
- logger.debug(
- f"规则{rule_id}:时间到,切换到 {next_state_candidate.value},resolve 为 {resolved_candidate.value}"
- )
- next_state = resolved_candidate
-
- if next_state is not None and next_state != current_status:
- return next_state
- else:
- return None
diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py
index 0d4ac281..b7267bdc 100644
--- a/src/chat/heart_flow/sub_heartflow.py
+++ b/src/chat/heart_flow/sub_heartflow.py
@@ -9,7 +9,6 @@ from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.focus_chat.heartFC_chat import HeartFChatting
from src.chat.normal_chat.normal_chat import NormalChat
-from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
from .utils_chat import get_chat_type_and_target_info
from src.config.config import global_config
@@ -23,7 +22,6 @@ class SubHeartflow:
def __init__(
self,
subheartflow_id,
- mai_states: MaiStateInfo,
):
"""子心流初始化函数
@@ -36,9 +34,6 @@ class SubHeartflow:
self.subheartflow_id = subheartflow_id
self.chat_id = subheartflow_id
- # 麦麦的状态
- self.mai_states = mai_states
-
# 这个聊天流的状态
self.chat_state: ChatStateInfo = ChatStateInfo()
self.chat_state_changed_time: float = time.time()
@@ -334,10 +329,10 @@ class SubHeartflow:
return self.normal_chat_instance.get_recent_replies(limit)
return []
- def add_interest_message(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
+ def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
# 如果字典长度超过10,删除最旧的消息
- if len(self.interest_dict) > 10:
+ if len(self.interest_dict) > 30:
oldest_key = next(iter(self.interest_dict))
self.interest_dict.pop(oldest_key)
diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py
index e3e68f55..bad4393c 100644
--- a/src/chat/heart_flow/subheartflow_manager.py
+++ b/src/chat/heart_flow/subheartflow_manager.py
@@ -4,7 +4,6 @@ from typing import Dict, Any, Optional, List
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
-from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
@@ -54,10 +53,9 @@ async def _try_set_subflow_absent_internal(subflow: "SubHeartflow", log_prefix:
class SubHeartflowManager:
"""管理所有活跃的 SubHeartflow 实例。"""
- def __init__(self, mai_state_info: MaiStateInfo):
+ def __init__(self):
self.subheartflows: Dict[Any, "SubHeartflow"] = {}
self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问
- self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
"""强制改变指定子心流的状态"""
@@ -97,7 +95,6 @@ class SubHeartflowManager:
# 初始化子心流, 传入 mai_state_info
new_subflow = SubHeartflow(
subheartflow_id,
- self.mai_state_info,
)
# 首先创建并添加聊天观察者
diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py
index 67bc1ab4..cb78c701 100644
--- a/src/chat/normal_chat/normal_chat.py
+++ b/src/chat/normal_chat/normal_chat.py
@@ -1,12 +1,9 @@
import asyncio
-import statistics # 导入 statistics 模块
import time
import traceback
from random import random
from typing import List, Optional # 导入 Optional
-
from maim_message import UserInfo, Seg
-
from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.manager.mood_manager import mood_manager
@@ -21,6 +18,7 @@ from src.chat.message_receive.message_sender import message_manager
from src.chat.utils.utils_image import image_path_to_base64
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager
+from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
from src.config.config import global_config
logger = get_logger("normal_chat")
@@ -39,6 +37,8 @@ class NormalChat:
self.is_group_chat: bool = False
self.chat_target_info: Optional[dict] = None
+
+ self.willing_amplifier = 1
# Other sync initializations
self.gpt = NormalChatGenerator()
@@ -209,10 +209,12 @@ class NormalChat:
for msg_id, (message, interest_value, is_mentioned) in items_to_process:
try:
# 处理消息
+ self.adjust_reply_frequency()
+
await self.normal_response(
message=message,
is_mentioned=is_mentioned,
- interested_rate=interest_value,
+ interested_rate=interest_value * self.willing_amplifier,
rewind_response=False,
)
except Exception as e:
@@ -228,26 +230,18 @@ class NormalChat:
if self._disabled:
logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。")
return
- # 检查收到的消息是否属于当前实例处理的 chat stream
- if message.chat_stream.stream_id != self.stream_id:
- logger.error(
- f"[{self.stream_name}] normal_response 收到不匹配的消息 (来自 {message.chat_stream.stream_id}),预期 {self.stream_id}。已忽略。"
- )
- return
timing_results = {}
-
reply_probability = 1.0 if is_mentioned else 0.0 # 如果被提及,基础概率为1,否则需要意愿判断
-
+
# 意愿管理器:设置当前message信息
-
willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate)
# 获取回复概率
- is_willing = False
+ # is_willing = False
# 仅在未被提及或基础概率不为1时查询意愿概率
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
- is_willing = True
+ # is_willing = True
reply_probability = await willing_manager.get_reply_probability(message.message_info.message_id)
if message.message_info.additional_config:
@@ -257,13 +251,13 @@ class NormalChat:
# 打印消息信息
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
- current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
+ # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
# 使用 self.stream_id
- willing_log = f"[回复意愿:{await willing_manager.get_willing(self.stream_id):.2f}]" if is_willing else ""
+ # willing_log = f"[激活值:{await willing_manager.get_willing(self.stream_id):.2f}]" if is_willing else ""
logger.info(
- f"[{current_time}][{mes_name}]"
+ f"[{mes_name}]"
f"{message.message_info.user_info.user_nickname}:" # 使用 self.chat_stream
- f"{message.processed_plain_text}{willing_log}[概率:{reply_probability * 100:.1f}%]"
+ f"{message.processed_plain_text}[回复概率:{reply_probability * 100:.1f}%]"
)
do_reply = False
response_set = None # 初始化 response_set
@@ -346,16 +340,13 @@ class NormalChat:
# 检查是否需要切换到focus模式
await self._check_switch_to_focus()
- else:
- logger.warning(f"[{self.stream_name}] 思考消息 {thinking_id} 在发送前丢失,无法记录 info_catcher")
info_catcher.done_catch()
- # 处理表情包 (不再需要传入 chat)
with Timer("处理表情包", timing_results):
await self._handle_emoji(message, response_set[0])
- # 更新关系情绪 (不再需要传入 chat)
+
with Timer("关系更新", timing_results):
await self._update_relationship(message, response_set)
@@ -479,9 +470,6 @@ class NormalChat:
# 统计1分钟内的回复数量
recent_reply_count = sum(1 for reply in self.recent_replies if reply["time"] > one_minute_ago)
- # print(111111111111111333333333333333333333333331111111111111111111111111111111111)
- # print(recent_reply_count)
- # 如果1分钟内回复数量大于8,触发切换到focus模式
if recent_reply_count > reply_threshold:
logger.info(
f"[{self.stream_name}] 检测到1分钟内回复数量({recent_reply_count})大于{reply_threshold},触发切换到focus模式"
@@ -491,3 +479,29 @@ class NormalChat:
await self.on_switch_to_focus_callback()
except Exception as e:
logger.error(f"[{self.stream_name}] 触发切换到focus模式时出错: {e}\n{traceback.format_exc()}")
+
+
+ def adjust_reply_frequency(self, duration: int = 10):
+ """
+ 调整回复频率
+ """
+ # 获取最近 duration 分钟内的消息统计(默认10分钟)
+ print(f"willing_amplifier: {self.willing_amplifier}")
+ stats = get_recent_message_stats(minutes=duration, chat_id=self.stream_id)
+ bot_reply_count = stats["bot_reply_count"]
+ print(f"[{self.stream_name}] 最近{duration}分钟内回复数量: {bot_reply_count}")
+ total_message_count = stats["total_message_count"]
+ print(f"[{self.stream_name}] 最近{duration}分钟内消息总数: {total_message_count}")
+
+ # 计算回复频率
+ _reply_frequency = bot_reply_count / total_message_count
+
+ # 如果回复频率低于0.5,增加回复概率
+ if bot_reply_count/duration < global_config.normal_chat.talk_frequency:
+ # differ = global_config.normal_chat.talk_frequency - reply_frequency
+ logger.info(f"[{self.stream_name}] 回复频率低于{global_config.normal_chat.talk_frequency},增加回复概率")
+ self.willing_amplifier += 0.1
+ else:
+ logger.info(f"[{self.stream_name}] 回复频率高于{global_config.normal_chat.talk_frequency},减少回复概率")
+ self.willing_amplifier -= 0.1
+
diff --git a/src/chat/normal_chat/normal_chat_utils.py b/src/chat/normal_chat/normal_chat_utils.py
new file mode 100644
index 00000000..2eae4829
--- /dev/null
+++ b/src/chat/normal_chat/normal_chat_utils.py
@@ -0,0 +1,33 @@
+import time
+from src.config.config import global_config
+from src.common.message_repository import count_messages
+
+
+def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
+ """
+ Args:
+ minutes (int): 检索的分钟数,默认30分钟
+ chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。
+ Returns:
+ dict: {"bot_reply_count": int, "total_message_count": int}
+ """
+
+ now = time.time()
+ start_time = now - minutes * 60
+ bot_id = global_config.bot.qq_account
+
+ filter_base = {"time": {"$gte": start_time}}
+ if chat_id is not None:
+ filter_base["chat_id"] = chat_id
+
+ # 总消息数
+ total_message_count = count_messages(filter_base)
+ # bot自身回复数
+ bot_filter = filter_base.copy()
+ bot_filter["user_id"] = bot_id
+ bot_reply_count = count_messages(bot_filter)
+
+ return {
+ "bot_reply_count": bot_reply_count,
+ "total_message_count": total_message_count
+ }
\ No newline at end of file
diff --git a/src/common/logger.py b/src/common/logger.py
index 3ed0fd7f..7258d619 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -271,7 +271,7 @@ CHAT_STYLE_CONFIG = {
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
},
"simple": {
- "console_format": "{time:HH:mm:ss} | 见闻 | {message}", # noqa: E501
+ "console_format": "{time:HH:mm:ss} | 见闻 | {message}", # noqa: E501
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
},
}
@@ -288,7 +288,7 @@ NORMAL_CHAT_STYLE_CONFIG = {
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
},
"simple": {
- "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501
+ "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}",
},
}
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 5e5199da..e353500f 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -91,7 +91,7 @@ class NormalChatConfig(ConfigBase):
max_context_size: int = 15
"""上下文长度"""
- message_buffer: bool = True
+ message_buffer: bool = False
"""消息缓冲器"""
emoji_chance: float = 0.2
@@ -103,6 +103,9 @@ class NormalChatConfig(ConfigBase):
willing_mode: str = "classical"
"""意愿模式"""
+ talk_frequency: float = 1
+ """回复频率阈值"""
+
response_willing_amplifier: float = 1.0
"""回复意愿放大系数"""
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index c7863340..946e3374 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -1,5 +1,5 @@
[inner]
-version = "2.5.0"
+version = "2.6.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件,请在修改后将version的值进行变更
@@ -18,22 +18,24 @@ nickname = "麦麦"
alias_names = ["麦叠", "牢麦"] #仅在 专注聊天 有效
[personality]
-personality_core = "是一个积极向上的女大学生" # 建议20字以内,谁再写3000字小作文敲谁脑袋
+personality_core = "是一个积极向上的女大学生" # 建议50字以内
personality_sides = [
"用一句话或几句话描述人格的一些细节",
"用一句话或几句话描述人格的一些细节",
"用一句话或几句话描述人格的一些细节",
-]# 条数任意,不能为0
+]
+# 条数任意,不能为0
# 身份特点
-[identity] #アイデンティティがない 生まれないらららら
+#アイデンティティがない 生まれないらららら
+[identity]
identity_detail = [
"年龄为19岁",
"是女孩子",
"身高为160cm",
"有橙色的短发",
]
-# 可以描述外贸,性别,身高,职业,属性等等描述
+# 可以描述外貌,性别,身高,职业,属性等等描述
# 条数任意,不能为0
[relationship]
@@ -69,15 +71,18 @@ normal_chat_first_probability = 0.3 # 麦麦回答时选择首要模型的概率
max_context_size = 15 #上下文长度
emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发
thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢)
-message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟
willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现)
+talk_frequency = 1 # 麦麦回复频率,一般为1,默认频率下,30分钟麦麦回复30条(约数)
+
response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1
response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数
-down_frequency_rate = 3 # 降低回复频率的群组回复意愿降低系数 除法
+
emoji_response_penalty = 0 # 表情包回复惩罚系数,设为0为不回复单个表情包,减少单独回复表情包的概率
mentioned_bot_inevitable_reply = false # 提及 bot 必然回复
at_bot_inevitable_reply = false # @bot 必然回复
+
+down_frequency_rate = 3 # 降低回复频率的群组回复意愿降低系数 除法
talk_frequency_down_groups = [] #降低回复频率的群号码
[focus_chat] #专注聊天
@@ -163,26 +168,8 @@ max_length = 256 # 回复允许的最大长度
max_sentence_num = 4 # 回复允许的最大句子数
enable_kaomoji_protection = false # 是否启用颜文字保护
-[maim_message]
-auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
-# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器
-use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复
-host="127.0.0.1"
-port=8090
-mode="ws" # 支持ws和tcp两种模式
-use_wss = false # 是否使用WSS安全连接,只支持ws模式
-cert_file = "" # SSL证书文件路径,仅在use_wss=true时有效
-key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效
-[telemetry] #发送统计信息,主要是看全球有多少只麦麦
-enable = true
-
-[experimental] #实验性功能
-debug_show_chat_mode = true # 是否在回复后显示当前聊天模式
-enable_friend_chat = false # 是否启用好友聊天
-pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立,在0.7.0暂时无效
-
#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写
# stream = : 用于指定模型是否是使用流式输出
@@ -333,5 +320,24 @@ pri_out = 8
+[maim_message]
+auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
+# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器
+use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复
+host="127.0.0.1"
+port=8090
+mode="ws" # 支持ws和tcp两种模式
+use_wss = false # 是否使用WSS安全连接,只支持ws模式
+cert_file = "" # SSL证书文件路径,仅在use_wss=true时有效
+key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效
+
+[telemetry] #发送统计信息,主要是看全球有多少只麦麦
+enable = true
+
+[experimental] #实验性功能
+debug_show_chat_mode = true # 是否在回复后显示当前聊天模式
+enable_friend_chat = false # 是否启用好友聊天
+pfc_chatting = false # 暂时无效
+
From 218d0d4a5dfe1948e3d494a83b413b68f4ba0a1b Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Wed, 28 May 2025 20:44:26 +0800
Subject: [PATCH 06/13] pass ruff
---
.../expressors/default_expressor.py | 19 ++++++++---------
src/chat/focus_chat/heartFC_chat.py | 4 +++-
.../focus_chat/heartflow_message_processor.py | 13 ++++--------
src/chat/heart_flow/background_tasks.py | 8 +------
src/chat/heart_flow/heartflow.py | 1 -
src/chat/heart_flow/sub_heartflow.py | 1 +
src/chat/message_receive/storage.py | 2 +-
src/chat/normal_chat/normal_chat.py | 21 +++++++------------
src/chat/normal_chat/normal_chat_generator.py | 10 ++++-----
src/chat/normal_chat/normal_chat_utils.py | 5 +----
src/chat/normal_chat/normal_prompt.py | 1 -
src/chat/utils/chat_message_builder.py | 5 +++--
src/config/official_configs.py | 6 ++----
src/plugins/test_plugin/actions/__init__.py | 1 +
.../test_plugin/actions/mute_action.py | 9 ++++----
template/bot_config_template.toml | 2 +-
16 files changed, 45 insertions(+), 63 deletions(-)
diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py
index cbea0a22..382d20aa 100644
--- a/src/chat/focus_chat/expressors/default_expressor.py
+++ b/src/chat/focus_chat/expressors/default_expressor.py
@@ -16,7 +16,6 @@ from src.chat.utils.info_catcher import info_catcher_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
-from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
@@ -106,10 +105,7 @@ class DefaultExpressor:
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
- # logger.debug(f"创建思考消息:{anchor_message}")
- # logger.debug(f"创建思考消息chat:{chat}")
- # logger.debug(f"创建思考消息bot_user_info:{bot_user_info}")
- # logger.debug(f"创建思考消息messageinfo:{messageinfo}")
+
thinking_message = MessageThinking(
message_id=thinking_id,
chat_stream=chat,
@@ -281,14 +277,14 @@ class DefaultExpressor:
in_mind_reply,
target_message,
) -> str:
- prompt_personality = individuality.get_prompt(x_person=0, level=2)
+ # prompt_personality = individuality.get_prompt(x_person=0, level=2)
# Determine if it's a group chat
is_group_chat = bool(chat_stream.group_info)
# Use sender_name passed from caller for private chat, otherwise use a default for group
# Default sender_name for group chat isn't used in the group prompt template, but set for consistency
- effective_sender_name = sender_name if not is_group_chat else "某人"
+ # effective_sender_name = sender_name if not is_group_chat else "某人"
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
@@ -377,7 +373,11 @@ class DefaultExpressor:
# --- 发送器 (Sender) --- #
async def send_response_messages(
- self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = "", display_message: str = ""
+ self,
+ anchor_message: Optional[MessageRecv],
+ response_set: List[Tuple[str, str]],
+ thinking_id: str = "",
+ display_message: str = "",
) -> Optional[MessageSending]:
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
@@ -412,10 +412,9 @@ class DefaultExpressor:
# 为每个消息片段生成唯一ID
type = msg_text[0]
data = msg_text[1]
-
+
if global_config.experimental.debug_show_chat_mode and type == "text":
data += "ᶠ"
-
part_message_id = f"{thinking_id}_{i}"
message_segment = Seg(type=type, data=data)
diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py
index ab061d85..63ce5c30 100644
--- a/src/chat/focus_chat/heartFC_chat.py
+++ b/src/chat/focus_chat/heartFC_chat.py
@@ -379,12 +379,14 @@ class HeartFChatting:
for processor in self.processors:
processor_name = processor.__class__.log_prefix
+
# 用lambda包裹,便于传参
async def run_with_timeout(proc=processor):
return await asyncio.wait_for(
proc.process_info(observations=observations, running_memorys=running_memorys),
- timeout=PROCESSOR_TIMEOUT
+ timeout=PROCESSOR_TIMEOUT,
)
+
task = asyncio.create_task(run_with_timeout())
processor_tasks.append(task)
task_to_name_map[task] = processor_name
diff --git a/src/chat/focus_chat/heartflow_message_processor.py b/src/chat/focus_chat/heartflow_message_processor.py
index 18349d7a..8277e69f 100644
--- a/src/chat/focus_chat/heartflow_message_processor.py
+++ b/src/chat/focus_chat/heartflow_message_processor.py
@@ -1,4 +1,3 @@
-import time
import traceback
from ..memory_system.Hippocampus import HippocampusManager
from ...config.config import global_config
@@ -73,12 +72,12 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
text_len = len(message.processed_plain_text)
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
# 采用对数函数实现递减增长
-
+
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
base_interest = min(max(base_interest, 0.01), 0.05)
-
+
interested_rate += base_interest
-
+
logger.trace(f"记忆激活率: {interested_rate:.2f}")
if is_mentioned:
@@ -220,11 +219,7 @@ class HeartFCMessageReceiver:
# 7. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
- logger.info(
- f"[{mes_name}]"
- f"{userinfo.user_nickname}:"
- f"{message.processed_plain_text}"
- )
+ logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
# 8. 关系处理
if global_config.relationship.give_name:
diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py
index db260456..4bacfd0a 100644
--- a/src/chat/heart_flow/background_tasks.py
+++ b/src/chat/heart_flow/background_tasks.py
@@ -8,8 +8,6 @@ from src.config.config import global_config
logger = get_logger("background_tasks")
-
-
# 新增私聊激活检查间隔
PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS = 5 # 与兴趣评估类似,设为5秒
@@ -136,9 +134,7 @@ class BackgroundTaskManager:
# 第三步:清空任务列表
self._tasks = [] # 重置任务列表
- # 状态转换处理
-
-
+ # 状态转换处理
async def _perform_cleanup_work(self):
"""执行子心流清理任务
@@ -165,13 +161,11 @@ class BackgroundTaskManager:
# 记录最终清理结果
logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流")
-
async def _run_cleanup_cycle(self):
await _run_periodic_loop(
task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work
)
-
# 新增私聊激活任务运行器
async def _run_private_chat_activation_cycle(self, interval: int):
await _run_periodic_loop(
diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py
index 7e12135a..d58c5cde 100644
--- a/src/chat/heart_flow/heartflow.py
+++ b/src/chat/heart_flow/heartflow.py
@@ -15,7 +15,6 @@ class Heartflow:
"""
def __init__(self):
-
# 子心流管理 (在初始化时传入 current_state)
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager()
diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py
index b7267bdc..c95f606b 100644
--- a/src/chat/heart_flow/sub_heartflow.py
+++ b/src/chat/heart_flow/sub_heartflow.py
@@ -18,6 +18,7 @@ logger = get_logger("sub_heartflow")
install(extra_lines=3)
+
class SubHeartflow:
def __init__(
self,
diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py
index 54dd68f0..8c05a9ab 100644
--- a/src/chat/message_receive/storage.py
+++ b/src/chat/message_receive/storage.py
@@ -24,7 +24,7 @@ class MessageStorage:
else:
filtered_processed_plain_text = ""
- if isinstance(message,MessageSending):
+ if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py
index cb78c701..98478183 100644
--- a/src/chat/normal_chat/normal_chat.py
+++ b/src/chat/normal_chat/normal_chat.py
@@ -37,7 +37,7 @@ class NormalChat:
self.is_group_chat: bool = False
self.chat_target_info: Optional[dict] = None
-
+
self.willing_amplifier = 1
# Other sync initializations
@@ -56,7 +56,6 @@ class NormalChat:
self._disabled = False # 增加停用标志
-
async def initialize(self):
"""异步初始化,获取聊天类型和目标信息。"""
if self._initialized:
@@ -210,7 +209,7 @@ class NormalChat:
try:
# 处理消息
self.adjust_reply_frequency()
-
+
await self.normal_response(
message=message,
is_mentioned=is_mentioned,
@@ -233,7 +232,7 @@ class NormalChat:
timing_results = {}
reply_probability = 1.0 if is_mentioned else 0.0 # 如果被提及,基础概率为1,否则需要意愿判断
-
+
# 意愿管理器:设置当前message信息
willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate)
@@ -306,7 +305,7 @@ class NormalChat:
return # 不执行后续步骤
logger.info(f"[{self.stream_name}] 回复内容: {response_set}")
-
+
if self._disabled:
logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。")
return
@@ -340,13 +339,11 @@ class NormalChat:
# 检查是否需要切换到focus模式
await self._check_switch_to_focus()
-
info_catcher.done_catch()
with Timer("处理表情包", timing_results):
await self._handle_emoji(message, response_set[0])
-
with Timer("关系更新", timing_results):
await self._update_relationship(message, response_set)
@@ -479,8 +476,7 @@ class NormalChat:
await self.on_switch_to_focus_callback()
except Exception as e:
logger.error(f"[{self.stream_name}] 触发切换到focus模式时出错: {e}\n{traceback.format_exc()}")
-
-
+
def adjust_reply_frequency(self, duration: int = 10):
"""
调整回复频率
@@ -492,16 +488,15 @@ class NormalChat:
print(f"[{self.stream_name}] 最近{duration}分钟内回复数量: {bot_reply_count}")
total_message_count = stats["total_message_count"]
print(f"[{self.stream_name}] 最近{duration}分钟内消息总数: {total_message_count}")
-
+
# 计算回复频率
_reply_frequency = bot_reply_count / total_message_count
-
+
# 如果回复频率低于0.5,增加回复概率
- if bot_reply_count/duration < global_config.normal_chat.talk_frequency:
+ if bot_reply_count / duration < global_config.normal_chat.talk_frequency:
# differ = global_config.normal_chat.talk_frequency - reply_frequency
logger.info(f"[{self.stream_name}] 回复频率低于{global_config.normal_chat.talk_frequency},增加回复概率")
self.willing_amplifier += 0.1
else:
logger.info(f"[{self.stream_name}] 回复频率高于{global_config.normal_chat.talk_frequency},减少回复概率")
self.willing_amplifier -= 0.1
-
diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py
index e0a88b4e..2ad1a197 100644
--- a/src/chat/normal_chat/normal_chat_generator.py
+++ b/src/chat/normal_chat/normal_chat_generator.py
@@ -64,11 +64,12 @@ class NormalChatGenerator:
async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str):
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
+ person_id = person_info_manager.get_person_id(
+ message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
+ )
- person_id = person_info_manager.get_person_id(message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id)
-
person_name = await person_info_manager.get_value(person_id, "person_name")
-
+
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
sender_name = (
f"[{message.chat_stream.user_info.user_nickname}]"
@@ -78,8 +79,7 @@ class NormalChatGenerator:
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
else:
sender_name = f"用户({message.chat_stream.user_info.user_id})"
-
-
+
# 构建prompt
with Timer() as t_build_prompt:
prompt = await prompt_builder.build_prompt(
diff --git a/src/chat/normal_chat/normal_chat_utils.py b/src/chat/normal_chat/normal_chat_utils.py
index 2eae4829..2ebd3bda 100644
--- a/src/chat/normal_chat/normal_chat_utils.py
+++ b/src/chat/normal_chat/normal_chat_utils.py
@@ -27,7 +27,4 @@ def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
bot_filter["user_id"] = bot_id
bot_reply_count = count_messages(bot_filter)
- return {
- "bot_reply_count": bot_reply_count,
- "total_message_count": total_message_count
- }
\ No newline at end of file
+ return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py
index 365d42de..54624a02 100644
--- a/src/chat/normal_chat/normal_prompt.py
+++ b/src/chat/normal_chat/normal_prompt.py
@@ -17,7 +17,6 @@ logger = get_logger("prompt")
def init_prompt():
-
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2")
diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py
index 38c9e9a8..66b58e25 100644
--- a/src/chat/utils/chat_message_builder.py
+++ b/src/chat/utils/chat_message_builder.py
@@ -10,6 +10,7 @@ from rich.traceback import install
install(extra_lines=3)
+
def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
@@ -198,7 +199,7 @@ async def _build_readable_messages_internal(
content = msg.get("display_message")
else:
content = msg.get("processed_plain_text", "") # 默认空字符串
-
+
if "ᶠ" in content:
content = content.replace("ᶠ", "")
if "ⁿ" in content:
@@ -465,7 +466,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
content = msg.get("display_message")
else:
content = msg.get("processed_plain_text", "")
-
+
if "ᶠ" in content:
content = content.replace("ᶠ", "")
if "ⁿ" in content:
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index e353500f..8f98256e 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -59,7 +59,7 @@ class ChatConfig(ConfigBase):
chat_mode: str = "normal"
"""聊天模式"""
-
+
auto_focus_threshold: float = 1.0
"""自动切换到专注聊天的阈值,越低越容易进入专注聊天"""
@@ -132,8 +132,6 @@ class NormalChatConfig(ConfigBase):
class FocusChatConfig(ConfigBase):
"""专注聊天配置类"""
-
-
observation_context_size: int = 12
"""可观察到的最长上下文大小,超过这个值的上下文会被压缩"""
@@ -346,7 +344,7 @@ class TelemetryConfig(ConfigBase):
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
- debug_show_chat_mode: bool = True
+ debug_show_chat_mode: bool = False
"""是否在回复后显示当前聊天模式"""
enable_friend_chat: bool = False
diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py
index 6797e51b..7d96ea8a 100644
--- a/src/plugins/test_plugin/actions/__init__.py
+++ b/src/plugins/test_plugin/actions/__init__.py
@@ -2,5 +2,6 @@
# 导入所有动作模块以确保装饰器被执行
from . import test_action # noqa
+
# from . import online_action # noqa
from . import mute_action # noqa
diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py
index 4f457d90..df112a16 100644
--- a/src/plugins/test_plugin/actions/mute_action.py
+++ b/src/plugins/test_plugin/actions/mute_action.py
@@ -57,14 +57,15 @@ class MuteAction(PluginAction):
# 确保duration是字符串类型
if int(duration) < 60:
duration = 60
- if int(duration) > 3600*24*30:
- duration = 3600*24*30
+ if int(duration) > 3600 * 24 * 30:
+ duration = 3600 * 24 * 30
duration_str = str(int(duration))
# 发送群聊禁言命令,按照新格式
await self.send_message(
- type = "command", data = {"name": "GROUP_BAN", "args": {"qq_id": str(user_id), "duration": duration_str}},
- display_message = f"我 禁言了 {target} {duration_str}秒"
+ type="command",
+ data={"name": "GROUP_BAN", "args": {"qq_id": str(user_id), "duration": duration_str}},
+ display_message=f"我 禁言了 {target} {duration_str}秒",
)
logger.info(f"{self.log_prefix} 成功发送禁言命令,用户 {target}({user_id}),时长 {duration} 秒")
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index 946e3374..2d4eccb9 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -335,7 +335,7 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效
enable = true
[experimental] #实验性功能
-debug_show_chat_mode = true # 是否在回复后显示当前聊天模式
+debug_show_chat_mode = false # 是否在回复后显示当前聊天模式
enable_friend_chat = false # 是否启用好友聊天
pfc_chatting = false # 暂时无效
From 859e5201fce76120cb12a303ed770443144ea36f Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Wed, 28 May 2025 20:45:34 +0800
Subject: [PATCH 07/13] Update default_expressor.py
---
src/chat/focus_chat/expressors/default_expressor.py | 7 -------
1 file changed, 7 deletions(-)
diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py
index 382d20aa..eceadf67 100644
--- a/src/chat/focus_chat/expressors/default_expressor.py
+++ b/src/chat/focus_chat/expressors/default_expressor.py
@@ -277,15 +277,8 @@ class DefaultExpressor:
in_mind_reply,
target_message,
) -> str:
- # prompt_personality = individuality.get_prompt(x_person=0, level=2)
-
- # Determine if it's a group chat
is_group_chat = bool(chat_stream.group_info)
- # Use sender_name passed from caller for private chat, otherwise use a default for group
- # Default sender_name for group chat isn't used in the group prompt template, but set for consistency
- # effective_sender_name = sender_name if not is_group_chat else "某人"
-
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
From 184a74bb8b18d6df2775bb59cfd80b07a07e8724 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com>
Date: Wed, 28 May 2025 20:50:18 +0800
Subject: [PATCH 08/13] 1
---
mongodb_to_sqlite .bat | 142 ++++++++++++++++++++---------------------
1 file changed, 71 insertions(+), 71 deletions(-)
diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat
index 5c029e44..f960e508 100644
--- a/mongodb_to_sqlite .bat
+++ b/mongodb_to_sqlite .bat
@@ -1,72 +1,72 @@
-@echo off
-CHCP 65001 > nul
-setlocal enabledelayedexpansion
-
-echo 你需要选择启动方式,输入字母来选择:
-echo V = 不知道什么意思就输入 V
-echo C = 输入 C 使用 Conda 环境
-echo.
-choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V
-
-set "ENV_TYPE="
-if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA"
-if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV"
-
-if "%ENV_TYPE%" == "CONDA" goto activate_conda
-if "%ENV_TYPE%" == "VENV" goto activate_venv
-
-REM 如果 choice 超时或返回意外值,默认使用 venv
-echo WARN: Invalid selection or timeout from choice. Defaulting to VENV.
-set "ENV_TYPE=VENV"
-goto activate_venv
-
-:activate_conda
- set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: "
- if not defined CONDA_ENV_NAME (
- echo 错误: 未输入 Conda 环境名称.
- pause
- exit /b 1
- )
- echo 选择: Conda '!CONDA_ENV_NAME!'
- REM 激活Conda环境
- call conda activate !CONDA_ENV_NAME!
- if !ERRORLEVEL! neq 0 (
- echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在.
- pause
- exit /b 1
- )
- goto env_activated
-
-:activate_venv
- echo Selected: venv (default or selected)
- REM 查找venv虚拟环境
- set "venv_path=%~dp0venv\Scripts\activate.bat"
- if not exist "%venv_path%" (
- echo Error: venv not found. Ensure the venv directory exists alongside the script.
- pause
- exit /b 1
- )
- REM 激活虚拟环境
- call "%venv_path%"
- if %ERRORLEVEL% neq 0 (
- echo Error: Failed to activate venv virtual environment.
- pause
- exit /b 1
- )
- goto env_activated
-
-:env_activated
-echo Environment activated successfully!
-
-REM --- 后续脚本执行 ---
-
-REM 运行预处理脚本
-python "%~dp0scripts\mongodb_to_sqlite.py"
-if %ERRORLEVEL% neq 0 (
- echo Error: mongodb_to_sqlite.py execution failed.
- pause
- exit /b 1
-)
-
-echo All processing steps completed!
+@echo off
+CHCP 65001 > nul
+setlocal enabledelayedexpansion
+
+echo 你需要选择启动方式,输入字母来选择:
+echo V = 不知道什么意思就输入 V
+echo C = 输入 C 使用 Conda 环境
+echo.
+choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V
+
+set "ENV_TYPE="
+if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA"
+if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV"
+
+if "%ENV_TYPE%" == "CONDA" goto activate_conda
+if "%ENV_TYPE%" == "VENV" goto activate_venv
+
+REM 如果 choice 超时或返回意外值,默认使用 venv
+echo WARN: Invalid selection or timeout from choice. Defaulting to VENV.
+set "ENV_TYPE=VENV"
+goto activate_venv
+
+:activate_conda
+ set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: "
+ if not defined CONDA_ENV_NAME (
+ echo 错误: 未输入 Conda 环境名称.
+ pause
+ exit /b 1
+ )
+ echo 选择: Conda '!CONDA_ENV_NAME!'
+ REM 激活Conda环境
+ call conda activate !CONDA_ENV_NAME!
+ if !ERRORLEVEL! neq 0 (
+ echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在.
+ pause
+ exit /b 1
+ )
+ goto env_activated
+
+:activate_venv
+ echo Selected: venv (default or selected)
+ REM 查找venv虚拟环境
+ set "venv_path=%~dp0venv\Scripts\activate.bat"
+ if not exist "%venv_path%" (
+ echo Error: venv not found. Ensure the venv directory exists alongside the script.
+ pause
+ exit /b 1
+ )
+ REM 激活虚拟环境
+ call "%venv_path%"
+ if %ERRORLEVEL% neq 0 (
+ echo Error: Failed to activate venv virtual environment.
+ pause
+ exit /b 1
+ )
+ goto env_activated
+
+:env_activated
+echo Environment activated successfully!
+
+REM --- 后续脚本执行 ---
+
+REM 运行预处理脚本
+python "%~dp0scripts\mongodb_to_sqlite.py"
+if %ERRORLEVEL% neq 0 (
+ echo Error: mongodb_to_sqlite.py execution failed.
+ pause
+ exit /b 1
+)
+
+echo All processing steps completed!
pause
\ No newline at end of file
From 460c7fb75ac839737f6cf7afc443d96a4b350271 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com>
Date: Wed, 28 May 2025 20:51:38 +0800
Subject: [PATCH 09/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=BF=81=E7=A7=BB?=
=?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=8C=E6=94=AF=E6=8C=81=E6=96=AD=E7=82=B9?=
=?UTF-8?q?=E7=BB=AD=E4=BC=A0=E5=92=8C=E6=89=B9=E9=87=8F=E5=AF=BC=E5=85=A5?=
=?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A8=A1?=
=?UTF-8?q?=E5=9E=8B=EF=BC=8C=E5=85=81=E8=AE=B8=E7=BE=A4=E8=81=8A=E4=BF=A1?=
=?UTF-8?q?=E6=81=AF=E5=AD=97=E6=AE=B5=E4=B8=BA=E5=8F=AF=E9=80=89=EF=BC=88?=
=?UTF-8?q?null=3DTrue=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
scripts/mongodb_to_sqlite.py | 485 ++++++++++++++++++++------
src/common/database/database_model.py | 6 +-
2 files changed, 374 insertions(+), 117 deletions(-)
diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py
index 5ff346af..609906fa 100644
--- a/scripts/mongodb_to_sqlite.py
+++ b/scripts/mongodb_to_sqlite.py
@@ -1,6 +1,9 @@
import os
import json
import sys # 新增系统模块导入
+# import time
+import pickle
+from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import Dict, Any, List, Optional, Type
@@ -10,16 +13,31 @@ from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from peewee import Model, Field, IntegrityError
+# Rich 进度条和显示组件
+from rich.console import Console
+from rich.progress import (
+ Progress,
+ TextColumn,
+ BarColumn,
+ TaskProgressColumn,
+ TimeRemainingColumn,
+ TimeElapsedColumn,
+ SpinnerColumn
+)
+from rich.table import Table
+from rich.panel import Panel
+# from rich.text import Text
+
from src.common.database.database import db
from src.common.database.database_model import (
ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions,
- OnlineTime, PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges
+ PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges
)
from src.common.logger_manager import get_logger
logger = get_logger("mongodb_to_sqlite")
-
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class MigrationConfig:
"""迁移配置类"""
@@ -32,6 +50,19 @@ class MigrationConfig:
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
+# 数据验证相关类已移除 - 用户要求不要数据验证
+
+
+@dataclass
+class MigrationCheckpoint:
+ """迁移断点数据"""
+ collection_name: str
+ processed_count: int
+ last_processed_id: Any
+ timestamp: datetime
+ batch_errors: List[Dict[str, Any]] = field(default_factory=list)
+
+
@dataclass
class MigrationStats:
"""迁移统计信息"""
@@ -41,7 +72,11 @@ class MigrationStats:
error_count: int = 0
skipped_count: int = 0
duplicate_count: int = 0
+ validation_errors: int = 0
+ batch_insert_count: int = 0
errors: List[Dict[str, Any]] = field(default_factory=list)
+ start_time: Optional[datetime] = None
+ end_time: Optional[datetime] = None
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
"""添加错误记录"""
@@ -52,6 +87,11 @@ class MigrationStats:
'doc_data': doc_data
})
self.error_count += 1
+
+ def add_validation_error(self, doc_id: Any, field: str, error: str):
+ """添加验证错误"""
+ self.add_error(doc_id, f"验证失败 - {field}: {error}")
+ self.validation_errors += 1
class MongoToSQLiteMigrator:
@@ -65,6 +105,15 @@ class MongoToSQLiteMigrator:
# 迁移配置
self.migration_configs = self._initialize_migration_configs()
+
+ # 进度条控制台
+ self.console = Console()
+ # 检查点目录
+ self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
+ self.checkpoint_dir.mkdir(exist_ok=True)
+
+ # 验证规则已禁用
+ self.validation_rules = self._initialize_validation_rules()
def _build_mongo_uri(self) -> str:
"""构建MongoDB连接URI"""
@@ -84,8 +133,7 @@ class MongoToSQLiteMigrator:
def _initialize_migration_configs(self) -> List[MigrationConfig]:
"""初始化迁移配置"""
- return [
- # 表情包迁移配置
+ return [ # 表情包迁移配置
MigrationConfig(
mongo_collection="emoji",
target_model=Emoji,
@@ -96,13 +144,13 @@ class MongoToSQLiteMigrator:
"description": "description",
"emotion": "emotion",
"usage_count": "usage_count",
- "last_used_time": "last_used_time",
- "last_used_time": "record_time" # 这个纯粹是为了应付整体映射格式,实际上直接用当前时间戳填了record_time
+ "last_used_time": "last_used_time"
+ # record_time字段将在转换时自动设置为当前时间
},
+ enable_validation=False, # 禁用数据验证
unique_fields=["full_path", "emoji_hash"]
),
-
- # 聊天流迁移配置
+ # 聊天流迁移配置
MigrationConfig(
mongo_collection="chat_streams",
target_model=ChatStreams,
@@ -119,10 +167,10 @@ class MongoToSQLiteMigrator:
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname"
},
+ enable_validation=False, # 禁用数据验证
unique_fields=["stream_id"]
),
-
- # LLM使用记录迁移配置
+ # LLM使用记录迁移配置
MigrationConfig(
mongo_collection="llm_usage",
target_model=LLMUsage,
@@ -138,10 +186,10 @@ class MongoToSQLiteMigrator:
"status": "status",
"timestamp": "timestamp"
},
+ enable_validation=False, # 禁用数据验证
unique_fields=["user_id", "timestamp"] # 组合唯一性
),
-
- # 消息迁移配置
+ # 消息迁移配置
MigrationConfig(
mongo_collection="messages",
target_model=Messages,
@@ -168,6 +216,7 @@ class MongoToSQLiteMigrator:
"detailed_plain_text": "detailed_plain_text",
"memorized_times": "memorized_times"
},
+ enable_validation=False, # 禁用数据验证
unique_fields=["message_id"]
),
@@ -194,21 +243,7 @@ class MongoToSQLiteMigrator:
"hash": "image_description_hash",
"description": "description",
"timestamp": "timestamp"
- },
- unique_fields=["image_description_hash", "type"]
- ),
-
- # 在线时长迁移配置
- MigrationConfig(
- mongo_collection="online_time",
- target_model=OnlineTime,
- field_mapping={
- "timestamp": "timestamp",
- "duration": "duration",
- "start_timestamp": "start_timestamp",
- "end_timestamp": "end_timestamp"
- },
- unique_fields=["start_timestamp", "end_timestamp"]
+ }, unique_fields=["image_description_hash", "type"]
),
# 个人信息迁移配置
@@ -290,6 +325,9 @@ class MongoToSQLiteMigrator:
unique_fields=["source", "target"] # 组合唯一性
)
]
+ def _initialize_validation_rules(self) -> Dict[str, Any]:
+ """数据验证已禁用 - 返回空字典"""
+ return {}
def connect_mongodb(self) -> bool:
"""连接到MongoDB"""
@@ -412,11 +450,69 @@ class MongoToSQLiteMigrator:
elif field_type in ["FloatField", "DoubleField"]:
return 0.0
elif field_type == "BooleanField":
- return False
+ return False
elif field_type == "DateTimeField":
return datetime.now()
return None
+ def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
+ """数据验证已禁用 - 始终返回True"""
+ return True
+
+ def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
+ """保存迁移断点"""
+ checkpoint = MigrationCheckpoint(
+ collection_name=collection_name,
+ processed_count=processed_count,
+ last_processed_id=last_id,
+ timestamp=datetime.now()
+ )
+
+ checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
+ try:
+ with open(checkpoint_file, 'wb') as f:
+ pickle.dump(checkpoint, f)
+ except Exception as e:
+ logger.warning(f"保存断点失败: {e}")
+
+ def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
+ """加载迁移断点"""
+ checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
+ if not checkpoint_file.exists():
+ return None
+
+ try:
+ with open(checkpoint_file, 'rb') as f:
+ return pickle.load(f)
+ except Exception as e:
+ logger.warning(f"加载断点失败: {e}")
+ return None
+
+ def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
+ """批量插入数据"""
+ if not data_list:
+ return 0
+
+ success_count = 0
+ try:
+ with db.atomic():
+ # 分批插入,避免SQL语句过长
+ batch_size = 100
+ for i in range(0, len(data_list), batch_size):
+ batch = data_list[i:i + batch_size]
+ model.insert_many(batch).execute()
+ success_count += len(batch)
+ except Exception as e:
+ logger.error(f"批量插入失败: {e}")
+ # 如果批量插入失败,尝试逐个插入
+ for data in data_list:
+ try:
+ model.create(**data)
+ success_count += 1
+ except Exception:
+ pass # 忽略单个插入失败
+
+ return success_count
def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any],
unique_fields: List[str]) -> bool:
@@ -458,17 +554,30 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.error(f"创建模型实例失败: {e}")
return None
-
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
- """迁移单个集合 - 使用ORM方式"""
+ """迁移单个集合 - 使用优化的批量插入和进度条"""
stats = MigrationStats()
+ stats.start_time = datetime.now()
+
+ # 检查是否有断点
+ checkpoint = self._load_checkpoint(config.mongo_collection)
+ start_from_id = checkpoint.last_processed_id if checkpoint else None
+ if checkpoint:
+ stats.processed_count = checkpoint.processed_count
+ logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")
logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
try:
# 获取MongoDB集合
mongo_collection = self.mongo_db[config.mongo_collection]
- stats.total_documents = mongo_collection.count_documents({})
+
+ # 构建查询条件(用于断点恢复)
+ query = {}
+ if start_from_id:
+ query = {"_id": {"$gt": start_from_id}}
+
+ stats.total_documents = mongo_collection.count_documents(query)
if stats.total_documents == 0:
logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
@@ -476,69 +585,112 @@ class MongoToSQLiteMigrator:
logger.info(f"待迁移文档数量: {stats.total_documents}")
- # 逐个处理文档
- batch_count = 0
- for mongo_doc in mongo_collection.find().batch_size(config.batch_size):
- try:
- stats.processed_count += 1
- doc_id = mongo_doc.get('_id', 'unknown')
-
- # 构建目标数据
- target_data = {}
- for mongo_field, sqlite_field in config.field_mapping.items():
- value = self._get_nested_value(mongo_doc, mongo_field)
-
- # 获取目标字段对象并转换类型
- if hasattr(config.target_model, sqlite_field):
- field_obj = getattr(config.target_model, sqlite_field)
- converted_value = self._convert_field_value(value, field_obj)
- target_data[sqlite_field] = converted_value
-
- # 重复检查
- if config.skip_duplicates and self._check_duplicate_by_unique_fields(
- config.target_model, target_data, config.unique_fields
- ):
- stats.duplicate_count += 1
- stats.skipped_count += 1
- logger.debug(f"跳过重复记录: {doc_id}")
- continue
-
- # 使用ORM创建实例
- with db.atomic(): # 每个实例的事务保护
- instance = self._create_model_instance(config.target_model, target_data)
-
- if instance:
- stats.success_count += 1
- else:
- stats.add_error(doc_id, "ORM创建实例失败", target_data)
+ # 创建Rich进度条
+ with Progress(
+ SpinnerColumn(),
+ TextColumn("[progress.description]{task.description}"),
+ BarColumn(),
+ TaskProgressColumn(),
+ TimeElapsedColumn(),
+ TimeRemainingColumn(),
+ console=self.console,
+ refresh_per_second=10
+ ) as progress:
- except Exception as e:
- doc_id = mongo_doc.get('_id', 'unknown')
- stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
- logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
+ task = progress.add_task(
+ f"迁移 {config.mongo_collection}",
+ total=stats.total_documents
+ )
+ # 批量处理数据
+ batch_data = []
+ batch_count = 0
+ last_processed_id = None
- # 进度报告
- batch_count += 1
- if batch_count % config.batch_size == 0:
- progress = (stats.processed_count / stats.total_documents) * 100
- logger.info(
- f"迁移进度: {stats.processed_count}/{stats.total_documents} "
- f"({progress:.1f}%) - 成功: {stats.success_count}, "
- f"错误: {stats.error_count}, 跳过: {stats.skipped_count}"
- )
+ for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
+ try:
+ doc_id = mongo_doc.get('_id', 'unknown')
+ last_processed_id = doc_id
+
+ # 构建目标数据
+ target_data = {}
+ for mongo_field, sqlite_field in config.field_mapping.items():
+ value = self._get_nested_value(mongo_doc, mongo_field)
+
+ # 获取目标字段对象并转换类型
+ if hasattr(config.target_model, sqlite_field):
+ field_obj = getattr(config.target_model, sqlite_field)
+ converted_value = self._convert_field_value(value, field_obj)
+ target_data[sqlite_field] = converted_value
+
+ # 数据验证已禁用
+ # if config.enable_validation:
+ # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
+ # stats.skipped_count += 1
+ # continue
+
+ # 重复检查
+ if config.skip_duplicates and self._check_duplicate_by_unique_fields(
+ config.target_model, target_data, config.unique_fields
+ ):
+ stats.duplicate_count += 1
+ stats.skipped_count += 1
+ logger.debug(f"跳过重复记录: {doc_id}")
+ continue
+
+ # 添加到批量数据
+ batch_data.append(target_data)
+ stats.processed_count += 1
+
+ # 执行批量插入
+ if len(batch_data) >= config.batch_size:
+ success_count = self._batch_insert(config.target_model, batch_data)
+ stats.success_count += success_count
+ stats.batch_insert_count += 1
+
+ # 保存断点
+ self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
+
+ batch_data.clear()
+ batch_count += 1
+
+ # 更新进度条
+ progress.update(task, advance=config.batch_size)
+
+ except Exception as e:
+ doc_id = mongo_doc.get('_id', 'unknown')
+ stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
+ logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
+
+ # 处理剩余的批量数据
+ if batch_data:
+ success_count = self._batch_insert(config.target_model, batch_data)
+ stats.success_count += success_count
+ stats.batch_insert_count += 1
+ progress.update(task, advance=len(batch_data))
+
+ # 完成进度条
+ progress.update(task, completed=stats.total_documents)
+
+ stats.end_time = datetime.now()
+ duration = stats.end_time - stats.start_time
logger.info(
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
- f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}"
+ f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
+ f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
)
+ # 清理断点文件
+ checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
+ if checkpoint_file.exists():
+ checkpoint_file.unlink()
+
except Exception as e:
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
stats.add_error("collection_error", str(e))
return stats
-
def migrate_all(self) -> Dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
@@ -550,18 +702,44 @@ class MongoToSQLiteMigrator:
all_stats = {}
try:
- for config in self.migration_configs:
- logger.info(f"\n开始处理集合: {config.mongo_collection}")
+ # 创建总体进度表格
+ total_collections = len(self.migration_configs)
+ self.console.print(Panel(
+ f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
+ f"[yellow]总集合数: {total_collections}[/yellow]",
+ title="迁移开始",
+ expand=False
+ ))
+ for idx, config in enumerate(self.migration_configs, 1):
+ self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]")
stats = self.migrate_collection(config)
all_stats[config.mongo_collection] = stats
+ # 显示单个集合的快速统计
+ if stats.processed_count > 0:
+ success_rate = (stats.success_count / stats.processed_count * 100)
+ if success_rate >= 95:
+ status_emoji = "✅"
+ status_color = "bright_green"
+ elif success_rate >= 80:
+ status_emoji = "⚠️"
+ status_color = "yellow"
+ else:
+ status_emoji = "❌"
+ status_color = "red"
+
+ self.console.print(
+ f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
+ f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
+ )
+
# 错误率检查
if stats.processed_count > 0:
error_rate = stats.error_count / stats.processed_count
if error_rate > 0.1: # 错误率超过10%
- logger.warning(
- f"集合 {config.mongo_collection} 错误率较高: {error_rate:.1%} "
- f"({stats.error_count}/{stats.processed_count})"
+ self.console.print(
+ f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
+ f"({stats.error_count}/{stats.processed_count})[/red]"
)
finally:
@@ -571,52 +749,131 @@ class MongoToSQLiteMigrator:
return all_stats
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
- """打印迁移汇总信息"""
- logger.info("\n" + "="*60)
- logger.info("数据迁移汇总报告")
- logger.info("="*60)
-
+ """使用Rich打印美观的迁移汇总信息"""
+ # 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
total_success = sum(stats.success_count for stats in all_stats.values())
total_errors = sum(stats.error_count for stats in all_stats.values())
total_skipped = sum(stats.skipped_count for stats in all_stats.values())
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
+ total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
+ total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
- # 表头
- logger.info(f"{'集合名称':<20} | {'处理':<6} | {'成功':<6} | {'错误':<6} | {'跳过':<6} | {'重复':<6} | {'成功率':<8}")
- logger.info("-" * 75)
+ # 计算总耗时
+ total_duration_seconds = 0
+ for stats in all_stats.values():
+ if stats.start_time and stats.end_time:
+ duration = stats.end_time - stats.start_time
+ total_duration_seconds += duration.total_seconds()
+
+ # 创建详细统计表格
+ table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
+ table.add_column("集合名称", style="cyan", width=20)
+ table.add_column("文档总数", justify="right", style="blue")
+ table.add_column("处理数量", justify="right", style="green")
+ table.add_column("成功数量", justify="right", style="green")
+ table.add_column("错误数量", justify="right", style="red")
+ table.add_column("跳过数量", justify="right", style="yellow")
+ table.add_column("重复数量", justify="right", style="bright_yellow")
+ table.add_column("验证错误", justify="right", style="red")
+ table.add_column("批次数", justify="right", style="purple")
+ table.add_column("成功率", justify="right", style="bright_green")
+ table.add_column("耗时(秒)", justify="right", style="blue")
for collection_name, stats in all_stats.items():
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
- logger.info(
- f"{collection_name:<20} | "
- f"{stats.processed_count:<6} | "
- f"{stats.success_count:<6} | "
- f"{stats.error_count:<6} | "
- f"{stats.skipped_count:<6} | "
- f"{stats.duplicate_count:<6} | "
- f"{success_rate:<7.1f}%"
+ duration = 0
+ if stats.start_time and stats.end_time:
+ duration = (stats.end_time - stats.start_time).total_seconds()
+
+ # 根据成功率设置颜色
+ if success_rate >= 95:
+ success_rate_style = "[bright_green]"
+ elif success_rate >= 80:
+ success_rate_style = "[yellow]"
+ else:
+ success_rate_style = "[red]"
+
+ table.add_row(
+ collection_name,
+ str(stats.total_documents),
+ str(stats.processed_count),
+ str(stats.success_count),
+ f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
+ f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
+ f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
+ f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
+ str(stats.batch_insert_count),
+ f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
+ f"{duration:.2f}"
)
- logger.info("-" * 75)
+ # 添加总计行
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
- logger.info(
- f"{'总计':<20} | "
- f"{total_processed:<6} | "
- f"{total_success:<6} | "
- f"{total_errors:<6} | "
- f"{total_skipped:<6} | "
- f"{total_duplicates:<6} | "
- f"{total_success_rate:<7.1f}%"
+ if total_success_rate >= 95:
+ total_rate_style = "[bright_green]"
+ elif total_success_rate >= 80:
+ total_rate_style = "[yellow]"
+ else:
+ total_rate_style = "[red]"
+
+ table.add_section()
+ table.add_row(
+ "[bold]总计[/bold]",
+ f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]",
+ f"[bold]{total_processed}[/bold]",
+ f"[bold]{total_success}[/bold]",
+ f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
+ f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
+ f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]",
+ f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
+ f"[bold]{total_batch_inserts}[/bold]",
+ f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
+ f"[bold]{total_duration_seconds:.2f}[/bold]"
)
+ self.console.print(table)
+
+ # 创建状态面板
+ status_items = []
if total_errors > 0:
- logger.warning(f"\n⚠️ 存在 {total_errors} 个错误,请检查日志详情")
+ status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
+
+ if total_validation_errors > 0:
+ status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
if total_duplicates > 0:
- logger.info(f"ℹ️ 跳过了 {total_duplicates} 个重复记录")
+ status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
- logger.info("="*60)
+ if total_success_rate >= 95:
+ status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
+ elif total_success_rate >= 80:
+ status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
+ else:
+ status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
+
+ if status_items:
+ status_panel = Panel(
+ "\n".join(status_items),
+ title="[bold yellow]迁移状态总结[/bold yellow]",
+ border_style="yellow"
+ )
+ self.console.print(status_panel)
+
+ # 性能统计面板
+ avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
+ performance_info = (
+ f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n"
+ f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
+ f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
+ )
+
+ performance_panel = Panel(
+ performance_info,
+ title="[bold green]性能统计[/bold green]",
+ border_style="green"
+ )
+ self.console.print(performance_panel)
def add_migration_config(self, config: MigrationConfig):
"""添加新的迁移配置"""
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
index ccf78964..bd264637 100644
--- a/src/common/database/database_model.py
+++ b/src/common/database/database_model.py
@@ -44,9 +44,9 @@ class ChatStreams(BaseModel):
# platform: "qq"
# group_id: "941657197"
# group_name: "测试"
- group_platform = TextField()
- group_id = TextField()
- group_name = TextField()
+ group_platform = TextField(null=True) # 群聊信息可能不存在
+ group_id = TextField(null=True)
+ group_name = TextField(null=True)
# last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位)
last_active_time = DoubleField()
From 847bd23a62e602b98b708dce293da62ad9873f24 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 28 May 2025 12:51:51 +0000
Subject: [PATCH 10/13] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?=
=?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
scripts/mongodb_to_sqlite.py | 443 +++++++++++++++++------------------
1 file changed, 220 insertions(+), 223 deletions(-)
diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py
index 609906fa..0f292fef 100644
--- a/scripts/mongodb_to_sqlite.py
+++ b/scripts/mongodb_to_sqlite.py
@@ -1,6 +1,7 @@
import os
import json
import sys # 新增系统模块导入
+
# import time
import pickle
from pathlib import Path
@@ -16,13 +17,13 @@ from peewee import Model, Field, IntegrityError
# Rich 进度条和显示组件
from rich.console import Console
from rich.progress import (
- Progress,
- TextColumn,
- BarColumn,
- TaskProgressColumn,
+ Progress,
+ TextColumn,
+ BarColumn,
+ TaskProgressColumn,
TimeRemainingColumn,
TimeElapsedColumn,
- SpinnerColumn
+ SpinnerColumn,
)
from rich.table import Table
from rich.panel import Panel
@@ -30,17 +31,29 @@ from rich.panel import Panel
from src.common.database.database import db
from src.common.database.database_model import (
- ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions,
- PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges
+ ChatStreams,
+ LLMUsage,
+ Emoji,
+ Messages,
+ Images,
+ ImageDescriptions,
+ PersonInfo,
+ Knowledges,
+ ThinkingLog,
+ GraphNodes,
+ GraphEdges,
)
from src.common.logger_manager import get_logger
logger = get_logger("mongodb_to_sqlite")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
+
@dataclass
class MigrationConfig:
"""迁移配置类"""
+
mongo_collection: str
target_model: Type[Model]
field_mapping: Dict[str, str]
@@ -56,6 +69,7 @@ class MigrationConfig:
@dataclass
class MigrationCheckpoint:
"""迁移断点数据"""
+
collection_name: str
processed_count: int
last_processed_id: Any
@@ -66,6 +80,7 @@ class MigrationCheckpoint:
@dataclass
class MigrationStats:
"""迁移统计信息"""
+
total_documents: int = 0
processed_count: int = 0
success_count: int = 0
@@ -77,17 +92,14 @@ class MigrationStats:
errors: List[Dict[str, Any]] = field(default_factory=list)
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
-
+
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
"""添加错误记录"""
- self.errors.append({
- 'doc_id': str(doc_id),
- 'error': error,
- 'timestamp': datetime.now().isoformat(),
- 'doc_data': doc_data
- })
+ self.errors.append(
+ {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
+ )
self.error_count += 1
-
+
def add_validation_error(self, doc_id: Any, field: str, error: str):
"""添加验证错误"""
self.add_error(doc_id, f"验证失败 - {field}: {error}")
@@ -96,81 +108,81 @@ class MigrationStats:
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
-
+
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
self.mongo_client: Optional[MongoClient] = None
self.mongo_db = None
-
+
# 迁移配置
self.migration_configs = self._initialize_migration_configs()
-
+
# 进度条控制台
self.console = Console()
- # 检查点目录
+ # 检查点目录
self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
self.checkpoint_dir.mkdir(exist_ok=True)
-
+
# 验证规则已禁用
self.validation_rules = self._initialize_validation_rules()
-
+
def _build_mongo_uri(self) -> str:
"""构建MongoDB连接URI"""
if mongo_uri := os.getenv("MONGODB_URI"):
return mongo_uri
-
- user = os.getenv('MONGODB_USER')
- password = os.getenv('MONGODB_PASS')
- host = os.getenv('MONGODB_HOST', 'localhost')
- port = os.getenv('MONGODB_PORT', '27017')
- auth_source = os.getenv('MONGODB_AUTH_SOURCE', 'admin')
-
+
+ user = os.getenv("MONGODB_USER")
+ password = os.getenv("MONGODB_PASS")
+ host = os.getenv("MONGODB_HOST", "localhost")
+ port = os.getenv("MONGODB_PORT", "27017")
+ auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
+
if user and password:
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
else:
return f"mongodb://{host}:{port}/{self.database_name}"
-
+
def _initialize_migration_configs(self) -> List[MigrationConfig]:
"""初始化迁移配置"""
- return [ # 表情包迁移配置
+ return [ # 表情包迁移配置
MigrationConfig(
mongo_collection="emoji",
target_model=Emoji,
field_mapping={
"full_path": "full_path",
- "format": "format",
+ "format": "format",
"hash": "emoji_hash",
"description": "description",
"emotion": "emotion",
"usage_count": "usage_count",
- "last_used_time": "last_used_time"
+ "last_used_time": "last_used_time",
# record_time字段将在转换时自动设置为当前时间
},
enable_validation=False, # 禁用数据验证
- unique_fields=["full_path", "emoji_hash"]
+ unique_fields=["full_path", "emoji_hash"],
),
- # 聊天流迁移配置
+ # 聊天流迁移配置
MigrationConfig(
mongo_collection="chat_streams",
target_model=ChatStreams,
field_mapping={
"stream_id": "stream_id",
"create_time": "create_time",
- "group_info.platform": "group_platform",# 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。
+ "group_info.platform": "group_platform", # MongoDB stores group_info as null for private chats; the target group_* columns are nullable, so these values are migrated as NULL.
"group_info.group_id": "group_id", # 同上
- "group_info.group_name": "group_name", # 同上
+ "group_info.group_name": "group_name", # 同上
"last_active_time": "last_active_time",
"platform": "platform",
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
- "user_info.user_cardname": "user_cardname"
+ "user_info.user_cardname": "user_cardname",
},
enable_validation=False, # 禁用数据验证
- unique_fields=["stream_id"]
+ unique_fields=["stream_id"],
),
- # LLM使用记录迁移配置
+ # LLM使用记录迁移配置
MigrationConfig(
mongo_collection="llm_usage",
target_model=LLMUsage,
@@ -184,12 +196,12 @@ class MongoToSQLiteMigrator:
"total_tokens": "total_tokens",
"cost": "cost",
"status": "status",
- "timestamp": "timestamp"
+ "timestamp": "timestamp",
},
enable_validation=False, # 禁用数据验证
- unique_fields=["user_id", "timestamp"] # 组合唯一性
+ unique_fields=["user_id", "timestamp"], # 组合唯一性
),
- # 消息迁移配置
+ # 消息迁移配置
MigrationConfig(
mongo_collection="messages",
target_model=Messages,
@@ -214,12 +226,11 @@ class MongoToSQLiteMigrator:
"user_info.user_cardname": "user_cardname",
"processed_plain_text": "processed_plain_text",
"detailed_plain_text": "detailed_plain_text",
- "memorized_times": "memorized_times"
+ "memorized_times": "memorized_times",
},
enable_validation=False, # 禁用数据验证
- unique_fields=["message_id"]
+ unique_fields=["message_id"],
),
-
# 图片迁移配置
MigrationConfig(
mongo_collection="images",
@@ -229,11 +240,10 @@ class MongoToSQLiteMigrator:
"description": "description",
"path": "path",
"timestamp": "timestamp",
- "type": "type"
+ "type": "type",
},
- unique_fields=["path"]
+ unique_fields=["path"],
),
-
# 图片描述迁移配置
MigrationConfig(
mongo_collection="image_descriptions",
@@ -242,10 +252,10 @@ class MongoToSQLiteMigrator:
"type": "type",
"hash": "image_description_hash",
"description": "description",
- "timestamp": "timestamp"
- }, unique_fields=["image_description_hash", "type"]
+ "timestamp": "timestamp",
+ },
+ unique_fields=["image_description_hash", "type"],
),
-
# 个人信息迁移配置
MigrationConfig(
mongo_collection="person_info",
@@ -260,22 +270,17 @@ class MongoToSQLiteMigrator:
"relationship_value": "relationship_value",
"konw_time": "know_time",
"msg_interval": "msg_interval",
- "msg_interval_list": "msg_interval_list"
+ "msg_interval_list": "msg_interval_list",
},
- unique_fields=["person_id"]
+ unique_fields=["person_id"],
),
-
# 知识库迁移配置
MigrationConfig(
mongo_collection="knowledges",
target_model=Knowledges,
- field_mapping={
- "content": "content",
- "embedding": "embedding"
- },
- unique_fields=["content"] # 假设内容唯一
+ field_mapping={"content": "content", "embedding": "embedding"},
+ unique_fields=["content"], # 假设内容唯一
),
-
# 思考日志迁移配置
MigrationConfig(
mongo_collection="thinking_log",
@@ -293,9 +298,8 @@ class MongoToSQLiteMigrator:
"heartflow_data": "heartflow_data_json",
"reasoning_data": "reasoning_data_json",
},
- unique_fields=["chat_id", "trigger_text"]
+ unique_fields=["chat_id", "trigger_text"],
),
-
# 图节点迁移配置
MigrationConfig(
mongo_collection="graph_data.nodes",
@@ -305,11 +309,10 @@ class MongoToSQLiteMigrator:
"memory_items": "memory_items",
"hash": "hash",
"created_time": "created_time",
- "last_modified": "last_modified"
+ "last_modified": "last_modified",
},
- unique_fields=["concept"]
+ unique_fields=["concept"],
),
-
# 图边迁移配置
MigrationConfig(
mongo_collection="graph_data.edges",
@@ -320,71 +323,69 @@ class MongoToSQLiteMigrator:
"strength": "strength",
"hash": "hash",
"created_time": "created_time",
- "last_modified": "last_modified"
+ "last_modified": "last_modified",
},
- unique_fields=["source", "target"] # 组合唯一性
- )
+ unique_fields=["source", "target"], # 组合唯一性
+ ),
]
+
def _initialize_validation_rules(self) -> Dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
-
+
def connect_mongodb(self) -> bool:
"""连接到MongoDB"""
try:
self.mongo_client = MongoClient(
- self.mongo_uri,
- serverSelectionTimeoutMS=5000,
- connectTimeoutMS=10000,
- maxPoolSize=10
+ self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
)
-
+
# 测试连接
- self.mongo_client.admin.command('ping')
+ self.mongo_client.admin.command("ping")
self.mongo_db = self.mongo_client[self.database_name]
-
+
logger.info(f"成功连接到MongoDB: {self.database_name}")
return True
-
+
except ConnectionFailure as e:
logger.error(f"MongoDB连接失败: {e}")
return False
except Exception as e:
logger.error(f"MongoDB连接异常: {e}")
return False
-
+
def disconnect_mongodb(self):
"""断开MongoDB连接"""
if self.mongo_client:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
-
+
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
-
+
parts = field_path.split(".")
value = document
-
+
for part in parts:
if isinstance(value, dict):
value = value.get(part)
else:
return None
-
+
if value is None:
break
-
+
return value
-
+
def _convert_field_value(self, value: Any, target_field: Field) -> Any:
"""根据目标字段类型转换值"""
if value is None:
return None
-
+
field_type = target_field.__class__.__name__
-
+
try:
if target_field.name == "record_time" and field_type == "DateTimeField":
return datetime.now()
@@ -393,31 +394,31 @@ class MongoToSQLiteMigrator:
if isinstance(value, (list, dict)):
return json.dumps(value, ensure_ascii=False)
return str(value) if value is not None else ""
-
+
elif field_type == "IntegerField":
if isinstance(value, str):
# 处理字符串数字
clean_value = value.strip()
- if clean_value.replace('.', '').replace('-', '').isdigit():
+ if clean_value.replace(".", "").replace("-", "").isdigit():
return int(float(clean_value))
return 0
return int(value) if value is not None else 0
-
+
elif field_type in ["FloatField", "DoubleField"]:
return float(value) if value is not None else 0.0
-
+
elif field_type == "BooleanField":
if isinstance(value, str):
- return value.lower() in ('true', '1', 'yes', 'on')
+ return value.lower() in ("true", "1", "yes", "on")
return bool(value)
-
+
elif field_type == "DateTimeField":
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value)
elif isinstance(value, str):
try:
# 尝试解析ISO格式日期
- return datetime.fromisoformat(value.replace('Z', '+00:00'))
+ return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
try:
# 尝试解析时间戳字符串
@@ -425,23 +426,23 @@ class MongoToSQLiteMigrator:
except ValueError:
return datetime.now()
return datetime.now()
-
+
return value
-
+
except (ValueError, TypeError) as e:
logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
return self._get_default_value_for_field(target_field)
-
+
def _get_default_value_for_field(self, field: Field) -> Any:
"""获取字段的默认值"""
field_type = field.__class__.__name__
-
- if hasattr(field, 'default') and field.default is not None:
+
+ if hasattr(field, "default") and field.default is not None:
return field.default
-
+
if field.null:
return None
-
+
# 根据字段类型返回默认值
if field_type in ["CharField", "TextField"]:
return ""
@@ -450,56 +451,57 @@ class MongoToSQLiteMigrator:
elif field_type in ["FloatField", "DoubleField"]:
return 0.0
elif field_type == "BooleanField":
- return False
+ return False
elif field_type == "DateTimeField":
return datetime.now()
-
+
return None
+
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
-
+
def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
"""保存迁移断点"""
checkpoint = MigrationCheckpoint(
collection_name=collection_name,
processed_count=processed_count,
last_processed_id=last_id,
- timestamp=datetime.now()
+ timestamp=datetime.now(),
)
-
+
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
try:
- with open(checkpoint_file, 'wb') as f:
+ with open(checkpoint_file, "wb") as f:
pickle.dump(checkpoint, f)
except Exception as e:
logger.warning(f"保存断点失败: {e}")
-
+
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
return None
-
+
try:
- with open(checkpoint_file, 'rb') as f:
+ with open(checkpoint_file, "rb") as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载断点失败: {e}")
return None
-
+
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
-
+
success_count = 0
try:
with db.atomic():
# 分批插入,避免SQL语句过长
batch_size = 100
for i in range(0, len(data_list), batch_size):
- batch = data_list[i:i + batch_size]
+ batch = data_list[i : i + batch_size]
model.insert_many(batch).execute()
success_count += len(batch)
except Exception as e:
@@ -511,27 +513,28 @@ class MongoToSQLiteMigrator:
success_count += 1
except Exception:
pass # 忽略单个插入失败
-
+
return success_count
-
- def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any],
- unique_fields: List[str]) -> bool:
+
+ def _check_duplicate_by_unique_fields(
+ self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
+ ) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
return False
-
+
try:
query = model.select()
for field_name in unique_fields:
if field_name in data and data[field_name] is not None:
field_obj = getattr(model, field_name)
query = query.where(field_obj == data[field_name])
-
+
return query.exists()
except Exception as e:
logger.debug(f"重复检查失败: {e}")
return False
-
+
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
"""使用ORM创建模型实例"""
try:
@@ -542,11 +545,11 @@ class MongoToSQLiteMigrator:
valid_data[field_name] = value
else:
logger.debug(f"跳过未知字段: {field_name}")
-
+
# 创建实例
instance = model.create(**valid_data)
return instance
-
+
except IntegrityError as e:
# 处理唯一约束冲突等完整性错误
logger.debug(f"完整性约束冲突: {e}")
@@ -554,37 +557,38 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.error(f"创建模型实例失败: {e}")
return None
+
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
"""迁移单个集合 - 使用优化的批量插入和进度条"""
stats = MigrationStats()
stats.start_time = datetime.now()
-
+
# 检查是否有断点
checkpoint = self._load_checkpoint(config.mongo_collection)
start_from_id = checkpoint.last_processed_id if checkpoint else None
if checkpoint:
stats.processed_count = checkpoint.processed_count
logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")
-
+
logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
-
+
try:
# 获取MongoDB集合
mongo_collection = self.mongo_db[config.mongo_collection]
-
+
# 构建查询条件(用于断点恢复)
query = {}
if start_from_id:
query = {"_id": {"$gt": start_from_id}}
-
+
stats.total_documents = mongo_collection.count_documents(query)
-
+
if stats.total_documents == 0:
logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
return stats
-
+
logger.info(f"待迁移文档数量: {stats.total_documents}")
-
+
# 创建Rich进度条
with Progress(
SpinnerColumn(),
@@ -594,40 +598,36 @@ class MongoToSQLiteMigrator:
TimeElapsedColumn(),
TimeRemainingColumn(),
console=self.console,
- refresh_per_second=10
+ refresh_per_second=10,
) as progress:
-
- task = progress.add_task(
- f"迁移 {config.mongo_collection}",
- total=stats.total_documents
- )
- # 批量处理数据
+ task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
+ # 批量处理数据
batch_data = []
batch_count = 0
last_processed_id = None
-
+
for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
try:
- doc_id = mongo_doc.get('_id', 'unknown')
+ doc_id = mongo_doc.get("_id", "unknown")
last_processed_id = doc_id
-
+
# 构建目标数据
target_data = {}
for mongo_field, sqlite_field in config.field_mapping.items():
value = self._get_nested_value(mongo_doc, mongo_field)
-
+
# 获取目标字段对象并转换类型
if hasattr(config.target_model, sqlite_field):
field_obj = getattr(config.target_model, sqlite_field)
converted_value = self._convert_field_value(value, field_obj)
target_data[sqlite_field] = converted_value
-
+
# 数据验证已禁用
# if config.enable_validation:
# if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
# stats.skipped_count += 1
# continue
-
+
# 重复检查
if config.skip_duplicates and self._check_duplicate_by_unique_fields(
config.target_model, target_data, config.unique_fields
@@ -636,88 +636,93 @@ class MongoToSQLiteMigrator:
stats.skipped_count += 1
logger.debug(f"跳过重复记录: {doc_id}")
continue
-
+
# 添加到批量数据
batch_data.append(target_data)
stats.processed_count += 1
-
+
# 执行批量插入
if len(batch_data) >= config.batch_size:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
-
+
# 保存断点
self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
-
+
batch_data.clear()
batch_count += 1
-
+
# 更新进度条
progress.update(task, advance=config.batch_size)
-
+
except Exception as e:
- doc_id = mongo_doc.get('_id', 'unknown')
+ doc_id = mongo_doc.get("_id", "unknown")
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
-
+
# 处理剩余的批量数据
if batch_data:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
progress.update(task, advance=len(batch_data))
-
+
# 完成进度条
progress.update(task, completed=stats.total_documents)
-
+
stats.end_time = datetime.now()
duration = stats.end_time - stats.start_time
-
+
logger.info(
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
)
-
+
# 清理断点文件
checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
if checkpoint_file.exists():
checkpoint_file.unlink()
-
+
except Exception as e:
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
stats.add_error("collection_error", str(e))
-
+
return stats
+
def migrate_all(self) -> Dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
-
+
if not self.connect_mongodb():
logger.error("无法连接到MongoDB,迁移终止")
return {}
-
+
all_stats = {}
-
+
try:
# 创建总体进度表格
total_collections = len(self.migration_configs)
- self.console.print(Panel(
- f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
- f"[yellow]总集合数: {total_collections}[/yellow]",
- title="迁移开始",
- expand=False
- ))
+ self.console.print(
+ Panel(
+ f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
+ f"[yellow]总集合数: {total_collections}[/yellow]",
+ title="迁移开始",
+ expand=False,
+ )
+ )
for idx, config in enumerate(self.migration_configs, 1):
- self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]")
+ self.console.print(
+ f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
+ )
stats = self.migrate_collection(config)
all_stats[config.mongo_collection] = stats
-
+
# 显示单个集合的快速统计
if stats.processed_count > 0:
- success_rate = (stats.success_count / stats.processed_count * 100)
+ success_rate = stats.success_count / stats.processed_count * 100
if success_rate >= 95:
status_emoji = "✅"
status_color = "bright_green"
@@ -727,12 +732,12 @@ class MongoToSQLiteMigrator:
else:
status_emoji = "❌"
status_color = "red"
-
+
self.console.print(
f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
)
-
+
# 错误率检查
if stats.processed_count > 0:
error_rate = stats.error_count / stats.processed_count
@@ -741,13 +746,13 @@ class MongoToSQLiteMigrator:
f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
f"({stats.error_count}/{stats.processed_count})[/red]"
)
-
+
finally:
self.disconnect_mongodb()
-
+
self._print_migration_summary(all_stats)
return all_stats
-
+
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
@@ -758,14 +763,14 @@ class MongoToSQLiteMigrator:
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
-
+
# 计算总耗时
total_duration_seconds = 0
for stats in all_stats.values():
if stats.start_time and stats.end_time:
duration = stats.end_time - stats.start_time
total_duration_seconds += duration.total_seconds()
-
+
# 创建详细统计表格
table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
table.add_column("集合名称", style="cyan", width=20)
@@ -779,13 +784,13 @@ class MongoToSQLiteMigrator:
table.add_column("批次数", justify="right", style="purple")
table.add_column("成功率", justify="right", style="bright_green")
table.add_column("耗时(秒)", justify="right", style="blue")
-
+
for collection_name, stats in all_stats.items():
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
duration = 0
if stats.start_time and stats.end_time:
duration = (stats.end_time - stats.start_time).total_seconds()
-
+
# 根据成功率设置颜色
if success_rate >= 95:
success_rate_style = "[bright_green]"
@@ -793,7 +798,7 @@ class MongoToSQLiteMigrator:
success_rate_style = "[yellow]"
else:
success_rate_style = "[red]"
-
+
table.add_row(
collection_name,
str(stats.total_documents),
@@ -805,9 +810,9 @@ class MongoToSQLiteMigrator:
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
str(stats.batch_insert_count),
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
- f"{duration:.2f}"
+ f"{duration:.2f}",
)
-
+
# 添加总计行
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
if total_success_rate >= 95:
@@ -816,7 +821,7 @@ class MongoToSQLiteMigrator:
total_rate_style = "[yellow]"
else:
total_rate_style = "[red]"
-
+
table.add_section()
table.add_row(
"[bold]总计[/bold]",
@@ -825,41 +830,41 @@ class MongoToSQLiteMigrator:
f"[bold]{total_success}[/bold]",
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
- f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]",
+ f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
+ if total_duplicates > 0
+ else "[bold]0[/bold]",
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
f"[bold]{total_batch_inserts}[/bold]",
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
- f"[bold]{total_duration_seconds:.2f}[/bold]"
+ f"[bold]{total_duration_seconds:.2f}[/bold]",
)
-
+
self.console.print(table)
-
+
# 创建状态面板
status_items = []
if total_errors > 0:
status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
-
+
if total_validation_errors > 0:
status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
-
+
if total_duplicates > 0:
status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
-
+
if total_success_rate >= 95:
status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
elif total_success_rate >= 80:
status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
else:
status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
-
+
if status_items:
status_panel = Panel(
- "\n".join(status_items),
- title="[bold yellow]迁移状态总结[/bold yellow]",
- border_style="yellow"
+ "\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
)
self.console.print(status_panel)
-
+
# 性能统计面板
avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
performance_info = (
@@ -867,60 +872,52 @@ class MongoToSQLiteMigrator:
f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
)
-
- performance_panel = Panel(
- performance_info,
- title="[bold green]性能统计[/bold green]",
- border_style="green"
- )
+
+ performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
self.console.print(performance_panel)
-
+
def add_migration_config(self, config: MigrationConfig):
"""添加新的迁移配置"""
self.migration_configs.append(config)
-
+
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
logger.error(f"未找到集合 {collection_name} 的迁移配置")
return None
-
+
if not self.connect_mongodb():
logger.error("无法连接到MongoDB")
return None
-
+
try:
stats = self.migrate_collection(config)
self._print_migration_summary({collection_name: stats})
return stats
finally:
self.disconnect_mongodb()
-
+
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
- 'timestamp': datetime.now().isoformat(),
- 'summary': {
+ "timestamp": datetime.now().isoformat(),
+ "summary": {
collection: {
- 'total': stats.total_documents,
- 'processed': stats.processed_count,
- 'success': stats.success_count,
- 'errors': stats.error_count,
- 'skipped': stats.skipped_count,
- 'duplicates': stats.duplicate_count
+ "total": stats.total_documents,
+ "processed": stats.processed_count,
+ "success": stats.success_count,
+ "errors": stats.error_count,
+ "skipped": stats.skipped_count,
+ "duplicates": stats.duplicate_count,
}
for collection, stats in all_stats.items()
},
- 'errors': {
- collection: stats.errors
- for collection, stats in all_stats.items()
- if stats.errors
- }
+ "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
}
-
+
try:
- with open(filepath, 'w', encoding='utf-8') as f:
+ with open(filepath, "w", encoding="utf-8") as f:
json.dump(error_report, f, ensure_ascii=False, indent=2)
logger.info(f"错误报告已导出到: {filepath}")
except Exception as e:
@@ -930,17 +927,17 @@ class MongoToSQLiteMigrator:
def main():
"""主程序入口"""
migrator = MongoToSQLiteMigrator()
-
+
# 执行迁移
migration_results = migrator.migrate_all()
-
+
# 导出错误报告(如果有错误)
if any(stats.error_count > 0 for stats in migration_results.values()):
error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
migrator.export_error_report(migration_results, error_report_path)
-
+
logger.info("数据迁移完成!")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
From f528d83826f6000f09a78a31af8afed8f668918c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com>
Date: Wed, 28 May 2025 20:52:15 +0800
Subject: [PATCH 11/13] =?UTF-8?q?=E9=87=8D=E5=86=99=20mongodb=5Fto=5Fsqlit?=
=?UTF-8?q?e.bat=20=E8=84=9A=E6=9C=AC=EF=BC=8C=E5=A2=9E=E5=8A=A0=E7=8E=AF?=
=?UTF-8?q?=E5=A2=83=E9=80=89=E6=8B=A9=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF?=
=?UTF-8?q?=E6=8C=81=20Conda=20=E5=92=8C=20venv=20=E6=BF=80=E6=B4=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mongodb_to_sqlite .bat => mongodb_to_sqlite.bat | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename mongodb_to_sqlite .bat => mongodb_to_sqlite.bat (100%)
diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite.bat
similarity index 100%
rename from mongodb_to_sqlite .bat
rename to mongodb_to_sqlite.bat
From 9bb2fe2d52777e54658fa7bc06409f48663ebf9c Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Wed, 28 May 2025 21:10:09 +0800
Subject: [PATCH 12/13] =?UTF-8?q?feat=EF=BC=9A=E4=B8=BAnoraml=5Fcaht?=
=?UTF-8?q?=E5=8A=A0=E5=85=A5=E8=A1=A8=E8=BE=BE=E6=96=B9=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
changelogs/changelog.md | 2 +-
src/chat/normal_chat/normal_chat.py | 3 +-
src/chat/normal_chat/normal_prompt.py | 99 +++++++++++++++++++++++----
template/bot_config_template.toml | 13 ++--
4 files changed, 95 insertions(+), 22 deletions(-)
diff --git a/changelogs/changelog.md b/changelogs/changelog.md
index c4dc8b3b..461bd7bb 100644
--- a/changelogs/changelog.md
+++ b/changelogs/changelog.md
@@ -36,7 +36,7 @@
- 优化了进入和离开normal_chat的方式
**新增表达方式学习**
-- 在专注模式下,麦麦可以有独特的表达方式
+- 麦麦配置单独表达方式
- 自主学习群聊中的表达方式,更贴近群友
- 可自定义的学习频率和开关
- 根据人设生成额外的表达方式
diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py
index 98478183..dc4da2ea 100644
--- a/src/chat/normal_chat/normal_chat.py
+++ b/src/chat/normal_chat/normal_chat.py
@@ -337,7 +337,8 @@ class NormalChat:
self.recent_replies = self.recent_replies[-self.max_replies_history :]
# 检查是否需要切换到focus模式
- await self._check_switch_to_focus()
+ if global_config.chat.chat_mode == "auto":
+ await self._check_switch_to_focus()
info_catcher.done_catch()
diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py
index 54624a02..8906106a 100644
--- a/src/chat/normal_chat/normal_prompt.py
+++ b/src/chat/normal_chat/normal_prompt.py
@@ -10,6 +10,7 @@ from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
+from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random
@@ -24,6 +25,11 @@ def init_prompt():
Prompt(
"""
+你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
+{style_habbits}
+请你根据情景使用以下句法,不要盲目使用,不要生硬使用,而是结合到表达中:
+{grammar_habbits}
+
{memory_prompt}
{relation_prompt}
{prompt_info}
@@ -31,7 +37,7 @@ def init_prompt():
{chat_talking_prompt}
现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言或者回复这条消息。\n
你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。
-你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1},
+你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},请你给出回复
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
@@ -49,6 +55,11 @@ def init_prompt():
Prompt(
"""
+你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
+{style_habbits}
+请你根据情景使用以下句法,不要盲目使用,不要生硬使用,而是结合到表达中:
+{grammar_habbits}
+
{memory_prompt}
{relation_prompt}
{prompt_info}
@@ -58,7 +69,7 @@ def init_prompt():
现在 {sender_name} 说的: {message_txt} 引起了你的注意,你想要回复这条消息。
你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。
-你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,{mood_prompt},{reply_style1},
+你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,{mood_prompt},请你给出回复
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号等),只输出回复内容。
@@ -103,15 +114,42 @@ class PromptBuilder:
relation_prompt += await relationship_manager.build_relationship_info(person)
mood_prompt = mood_manager.get_mood_prompt()
- reply_styles1 = [
- ("然后给出日常且口语化的回复,平淡一些", 0.4),
- ("给出非常简短的回复", 0.4),
- ("给出缺失主语的回复", 0.15),
- ("给出带有语病的回复", 0.05),
- ]
- reply_style1_chosen = random.choices(
- [style[0] for style in reply_styles1], weights=[style[1] for style in reply_styles1], k=1
- )[0]
+
+
+ (
+ learnt_style_expressions,
+ learnt_grammar_expressions,
+ personality_expressions,
+ ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
+
+ style_habbits = []
+ grammar_habbits = []
+ # 1. learnt_expressions加权随机选2条
+ if learnt_style_expressions:
+ weights = [expr["count"] for expr in learnt_style_expressions]
+ selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 2)
+ for expr in selected_learnt:
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+ # 2. learnt_grammar_expressions加权随机选2条
+ if learnt_grammar_expressions:
+ weights = [expr["count"] for expr in learnt_grammar_expressions]
+ selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 2)
+ for expr in selected_learnt:
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+ # 3. personality_expressions随机选1条
+ if personality_expressions:
+ expr = random.choice(personality_expressions)
+ if isinstance(expr, dict) and "situation" in expr and "style" in expr:
+ style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
+
+ style_habbits_str = "\n".join(style_habbits)
+ grammar_habbits_str = "\n".join(grammar_habbits)
+
+
+
+
reply_styles2 = [
("不要回复的太有条理,可以有个性", 0.6),
("不要回复的太有条理,可以复读", 0.15),
@@ -208,7 +246,8 @@ class PromptBuilder:
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
- reply_style1=reply_style1_chosen,
+ style_habbits=style_habbits_str,
+ grammar_habbits=grammar_habbits_str,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
@@ -231,7 +270,8 @@ class PromptBuilder:
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
- reply_style1=reply_style1_chosen,
+ style_habbits=style_habbits_str,
+ grammar_habbits=grammar_habbits_str,
reply_style2=reply_style2_chosen,
keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger,
@@ -266,6 +306,39 @@ class PromptBuilder:
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return "未检索到知识"
+
+def weighted_sample_no_replacement(items, weights, k) -> list:
+    """
+    Draw up to ``k`` distinct elements at random, biased by weight.
+
+    Args:
+        items: Candidate elements to draw from.
+        weights: One weight per element, paired with ``items`` by position;
+            expected to be positive and as many as ``items``.
+        k: Number of elements to draw.
+
+    Returns:
+        A list of the drawn elements in draw order. When fewer than ``k``
+        candidates exist, every candidate is returned.
+
+    Each round draws a uniform threshold in ``[0, remaining weight sum]``,
+    takes the first candidate whose running weight total reaches it, and
+    removes that candidate so it cannot be drawn again.
+    """
+    candidates = list(zip(items, weights))
+    chosen = []
+    for _ in range(min(k, len(candidates))):
+        weight_sum = sum(pair[1] for pair in candidates)
+        threshold = random.uniform(0, weight_sum)
+        running = 0
+        position = 0
+        while position < len(candidates):
+            running += candidates[position][1]
+            if threshold <= running:
+                chosen.append(candidates.pop(position)[0])
+                break
+            position += 1
+    return chosen
init_prompt()
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index 2d4eccb9..435abddf 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -38,6 +38,12 @@ identity_detail = [
# 可以描述外貌,性别,身高,职业,属性等等描述
# 条数任意,不能为0
+[expression]
+# 表达方式
+expression_style = "描述麦麦说话的表达风格,表达习惯"
+enable_expression_learning = true # 是否启用表达学习,麦麦会学习人类说话风格
+learning_interval = 600 # 学习间隔 单位秒
+
[relationship]
give_name = true # 麦麦是否给其他人取名,关闭后无法使用禁言功能
@@ -97,13 +103,6 @@ self_identify_processor = true # 是否启用自我识别处理器
tool_use_processor = false # 是否启用工具使用处理器
working_memory_processor = false # 是否启用工作记忆处理器
-[expression]
-# 表达方式
-expression_style = "描述麦麦说话的表达风格,表达习惯"
-enable_expression_learning = true # 是否启用表达学习
-learning_interval = 600 # 学习间隔 单位秒
-
-
[emoji]
max_reg_num = 40 # 表情包最大注册数量
do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
From f57903ff7c95b54a80be5a46c2b52a3fb16ce61b Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 28 May 2025 13:10:24 +0000
Subject: [PATCH 13/13] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?=
=?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
src/chat/normal_chat/normal_prompt.py | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py
index 8906106a..9618987a 100644
--- a/src/chat/normal_chat/normal_prompt.py
+++ b/src/chat/normal_chat/normal_prompt.py
@@ -114,14 +114,13 @@ class PromptBuilder:
relation_prompt += await relationship_manager.build_relationship_info(person)
mood_prompt = mood_manager.get_mood_prompt()
-
-
+
(
learnt_style_expressions,
learnt_grammar_expressions,
personality_expressions,
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
-
+
style_habbits = []
grammar_habbits = []
# 1. learnt_expressions加权随机选2条
@@ -146,10 +145,7 @@ class PromptBuilder:
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
-
-
-
-
+
reply_styles2 = [
("不要回复的太有条理,可以有个性", 0.6),
("不要回复的太有条理,可以复读", 0.15),
@@ -306,7 +302,8 @@ class PromptBuilder:
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return "未检索到知识"
-
+
+
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。