From 5136c617ceccf1823e6666bfceded0821cf4cbc1 Mon Sep 17 00:00:00 2001 From: A0000Xz <122650088+A0000Xz@users.noreply.github.com> Date: Wed, 28 May 2025 17:59:45 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E5=B7=A5=E5=85=B7=E7=9A=84?= =?UTF-8?q?=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/mongodb_to_sqlite.py | 692 +++++++++++++++++++++++++++++++++++ 1 file changed, 692 insertions(+) create mode 100644 scripts/mongodb_to_sqlite.py diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py new file mode 100644 index 00000000..83e08679 --- /dev/null +++ b/scripts/mongodb_to_sqlite.py @@ -0,0 +1,692 @@ +import os +import json +import sys # 新增系统模块导入 + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from typing import Dict, Any, List, Optional, Type +from dataclasses import dataclass, field +from datetime import datetime +from pymongo import MongoClient +from pymongo.errors import ConnectionFailure +from peewee import Model, Field, IntegrityError + +from src.common.database.database import db +from src.common.database.database_model import ( + ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions, + OnlineTime, PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges +) +from src.common.logger_manager import get_logger + +logger = get_logger("mongodb_to_sqlite") + + +@dataclass +class MigrationConfig: + """迁移配置类""" + mongo_collection: str + target_model: Type[Model] + field_mapping: Dict[str, str] + batch_size: int = 500 + enable_validation: bool = True + skip_duplicates: bool = True + unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 + + +@dataclass +class MigrationStats: + """迁移统计信息""" + total_documents: int = 0 + processed_count: int = 0 + success_count: int = 0 + error_count: int = 0 + skipped_count: int = 0 + duplicate_count: int = 0 + errors: List[Dict[str, Any]] = field(default_factory=list) + + def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): + """添加错误记录""" + self.errors.append({ + 'doc_id': str(doc_id), + 'error': error, + 'timestamp': datetime.now().isoformat(), + 'doc_data': doc_data + }) + self.error_count += 1 + + +class MongoToSQLiteMigrator: + """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" + + def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None): + self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") + self.mongo_uri = mongo_uri or self._build_mongo_uri() + self.mongo_client: Optional[MongoClient] = None + self.mongo_db = None + + # 迁移配置 + self.migration_configs = self._initialize_migration_configs() + + def _build_mongo_uri(self) -> str: + """构建MongoDB连接URI""" + if mongo_uri := os.getenv("MONGODB_URI"): + return mongo_uri + + user = os.getenv('MONGODB_USER') + password = os.getenv('MONGODB_PASS') + host = os.getenv('MONGODB_HOST', 'localhost') + port = os.getenv('MONGODB_PORT', '27017') + auth_source = os.getenv('MONGODB_AUTH_SOURCE', 'admin') + + if user and password: + return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}" + else: + return f"mongodb://{host}:{port}/{self.database_name}" + + def _initialize_migration_configs(self) -> List[MigrationConfig]: + """初始化迁移配置""" + return [ + # 表情包迁移配置 + MigrationConfig( + mongo_collection="emoji", + target_model=Emoji, + field_mapping={ + "full_path": "full_path", + "format": "format", + "hash": "emoji_hash", + "description": "description", + "emotion": "emotion", + "usage_count": "usage_count", + "last_used_time": "last_used_time", + "last_used_time": "record_time" # 这个纯粹是为了应付整体映射格式,实际上直接用当前时间戳填了record_time + }, + unique_fields=["full_path", "emoji_hash"] + ), + + # 聊天流迁移配置 + MigrationConfig( + mongo_collection="chat_streams", + target_model=ChatStreams, + field_mapping={ + "stream_id": "stream_id", + "create_time": "create_time", + "group_info.platform": "group_platform",# 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。 + "group_info.group_id": "group_id", # 同上 + "group_info.group_name": "group_name", # 同上 + "last_active_time": "last_active_time", + "platform": "platform", + "user_info.platform": "user_platform", + "user_info.user_id": "user_id", + "user_info.user_nickname": "user_nickname", + "user_info.user_cardname": "user_cardname" + }, + unique_fields=["stream_id"] + ), + + # LLM使用记录迁移配置 + MigrationConfig( + mongo_collection="llm_usage", + target_model=LLMUsage, + field_mapping={ + "model_name": "model_name", + "user_id": "user_id", + "request_type": "request_type", + "endpoint": "endpoint", + "prompt_tokens": "prompt_tokens", + "completion_tokens": "completion_tokens", + "total_tokens": "total_tokens", + "cost": "cost", + "status": "status", + "timestamp": "timestamp" + }, + unique_fields=["user_id", "timestamp"] # 组合唯一性 + ), + + # 消息迁移配置 + MigrationConfig( + mongo_collection="messages", + target_model=Messages, + field_mapping={ + "message_id": "message_id", + "time": "time", + "chat_id": "chat_id", + "chat_info.stream_id": "chat_info_stream_id", + "chat_info.platform": "chat_info_platform", + "chat_info.user_info.platform": "chat_info_user_platform", + "chat_info.user_info.user_id": "chat_info_user_id", + "chat_info.user_info.user_nickname": "chat_info_user_nickname", + "chat_info.user_info.user_cardname": "chat_info_user_cardname", + "chat_info.group_info.platform": "chat_info_group_platform", + "chat_info.group_info.group_id": "chat_info_group_id", + "chat_info.group_info.group_name": "chat_info_group_name", + "chat_info.create_time": "chat_info_create_time", + "chat_info.last_active_time": "chat_info_last_active_time", + "user_info.platform": "user_platform", + "user_info.user_id": "user_id", + "user_info.user_nickname": "user_nickname", + "user_info.user_cardname": "user_cardname", + "processed_plain_text": "processed_plain_text", + "detailed_plain_text": "detailed_plain_text", + "memorized_times": "memorized_times" + }, + unique_fields=["message_id"] + ), + + # 图片迁移配置 + MigrationConfig( + mongo_collection="images", + target_model=Images, + field_mapping={ + "hash": "emoji_hash", + "description": "description", + "path": "path", + "timestamp": "timestamp", + "type": "type" + }, + unique_fields=["path"] + ), + + # 图片描述迁移配置 + MigrationConfig( + mongo_collection="image_descriptions", + target_model=ImageDescriptions, + field_mapping={ + "type": "type", + "hash": "image_description_hash", + "description": "description", + "timestamp": "timestamp" + }, + unique_fields=["image_description_hash", "type"] + ), + + # 在线时长迁移配置 + MigrationConfig( + mongo_collection="online_time", + target_model=OnlineTime, + field_mapping={ + "timestamp": "timestamp", + "duration": "duration", + "start_timestamp": "start_timestamp", + "end_timestamp": "end_timestamp" + }, + unique_fields=["start_timestamp", "end_timestamp"] + ), + + # 个人信息迁移配置 + MigrationConfig( + mongo_collection="person_info", + target_model=PersonInfo, + field_mapping={ + "person_id": "person_id", + "person_name": "person_name", + "name_reason": "name_reason", + "platform": "platform", + "user_id": "user_id", + "nickname": "nickname", + "relationship_value": "relationship_value", + "konw_time": "know_time", + "msg_interval": "msg_interval", + "msg_interval_list": "msg_interval_list" + }, + unique_fields=["person_id"] + ), + + # 知识库迁移配置 + MigrationConfig( + mongo_collection="knowledges", + target_model=Knowledges, + field_mapping={ + "content": "content", + "embedding": "embedding" + }, + unique_fields=["content"] # 假设内容唯一 + ), + + # 思考日志迁移配置 + MigrationConfig( + mongo_collection="thinking_log", + target_model=ThinkingLog, + field_mapping={ + "chat_id": "chat_id", + "trigger_text": "trigger_text", + "response_text": "response_text", + "trigger_info": "trigger_info_json", + "response_info": "response_info_json", + "timing_results": "timing_results_json", + "chat_history": "chat_history_json", + "chat_history_in_thinking": "chat_history_in_thinking_json", + "chat_history_after_response": "chat_history_after_response_json", + "heartflow_data": "heartflow_data_json", + "reasoning_data": "reasoning_data_json", + }, + unique_fields=["chat_id", "created_at"] + ), + + # 图节点迁移配置 + MigrationConfig( + mongo_collection="graph_data.nodes", + target_model=GraphNodes, + field_mapping={ + "concept": "concept", + "memory_items": "memory_items", + "hash": "hash", + "created_time": "created_time", + "last_modified": "last_modified" + }, + unique_fields=["concept"] + ), + + # 图边迁移配置 + MigrationConfig( + mongo_collection="graph_data.edges", + target_model=GraphEdges, + field_mapping={ + "source": "source", + "target": "target", + "strength": "strength", + "hash": "hash", + "created_time": "created_time", + "last_modified": "last_modified" + }, + unique_fields=["source", "target"] # 组合唯一性 + ) + ] + + def connect_mongodb(self) -> bool: + """连接到MongoDB""" + try: + self.mongo_client = MongoClient( + self.mongo_uri, + serverSelectionTimeoutMS=5000, + connectTimeoutMS=10000, + maxPoolSize=10 + ) + + # 测试连接 + self.mongo_client.admin.command('ping') + self.mongo_db = self.mongo_client[self.database_name] + + logger.info(f"成功连接到MongoDB: {self.database_name}") + return True + + except ConnectionFailure as e: + logger.error(f"MongoDB连接失败: {e}") + return False + except Exception as e: + logger.error(f"MongoDB连接异常: {e}") + return False + + def disconnect_mongodb(self): + """断开MongoDB连接""" + if self.mongo_client: + self.mongo_client.close() + logger.info("MongoDB连接已关闭") + + def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any: + """获取嵌套字段的值""" + if "." not in field_path: + return document.get(field_path) + + parts = field_path.split(".") + value = document + + for part in parts: + if isinstance(value, dict): + value = value.get(part) + else: + return None + + if value is None: + break + + return value + + def _convert_field_value(self, value: Any, target_field: Field) -> Any: + """根据目标字段类型转换值""" + if value is None: + return None + + field_type = target_field.__class__.__name__ + + try: + if target_field.name == "record_time" and field_type == "DateTimeField": + return datetime.now() + + if target_field.name == "record_time" and field_type == "DateTimeField": + return self._convert_record_time(value) + + if field_type in ["CharField", "TextField"]: + if isinstance(value, (list, dict)): + return json.dumps(value, ensure_ascii=False) + return str(value) if value is not None else "" + + elif field_type == "IntegerField": + if isinstance(value, str): + # 处理字符串数字 + clean_value = value.strip() + if clean_value.replace('.', '').replace('-', '').isdigit(): + return int(float(clean_value)) + return 0 + return int(value) if value is not None else 0 + + elif field_type in ["FloatField", "DoubleField"]: + return float(value) if value is not None else 0.0 + + elif field_type == "BooleanField": + if isinstance(value, str): + return value.lower() in ('true', '1', 'yes', 'on') + return bool(value) + + elif field_type == "DateTimeField": + if isinstance(value, (int, float)): + return datetime.fromtimestamp(value) + elif isinstance(value, str): + try: + # 尝试解析ISO格式日期 + return datetime.fromisoformat(value.replace('Z', '+00:00')) + except ValueError: + try: + # 尝试解析时间戳字符串 + return datetime.fromtimestamp(float(value)) + except ValueError: + return datetime.now() + return datetime.now() + + return value + + except (ValueError, TypeError) as e: + logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}") + return self._get_default_value_for_field(target_field) + + def _get_default_value_for_field(self, field: Field) -> Any: + """获取字段的默认值""" + field_type = field.__class__.__name__ + + if hasattr(field, 'default') and field.default is not None: + return field.default + + if field.null: + return None + + # 根据字段类型返回默认值 + if field_type in ["CharField", "TextField"]: + return "" + elif field_type == "IntegerField": + return 0 + elif field_type in ["FloatField", "DoubleField"]: + return 0.0 + elif field_type == "BooleanField": + return False + elif field_type == "DateTimeField": + return datetime.now() + + return None + + def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any], + unique_fields: List[str]) -> bool: + """根据唯一字段检查重复""" + if not unique_fields: + return False + + try: + query = model.select() + for field_name in unique_fields: + if field_name in data and data[field_name] is not None: + field_obj = getattr(model, field_name) + query = query.where(field_obj == data[field_name]) + + return query.exists() + except Exception as e: + logger.debug(f"重复检查失败: {e}") + return False + + def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]: + """使用ORM创建模型实例""" + try: + # 过滤掉不存在的字段 + valid_data = {} + for field_name, value in data.items(): + if hasattr(model, field_name): + valid_data[field_name] = value + else: + logger.debug(f"跳过未知字段: {field_name}") + + # 创建实例 + instance = model.create(**valid_data) + return instance + + except IntegrityError as e: + # 处理唯一约束冲突等完整性错误 + logger.debug(f"完整性约束冲突: {e}") + return None + except Exception as e: + logger.error(f"创建模型实例失败: {e}") + return None + + def migrate_collection(self, config: MigrationConfig) -> MigrationStats: + """迁移单个集合 - 使用ORM方式""" + stats = MigrationStats() + + logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") + + try: + # 获取MongoDB集合 + mongo_collection = self.mongo_db[config.mongo_collection] + stats.total_documents = mongo_collection.count_documents({}) + + if stats.total_documents == 0: + logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") + return stats + + logger.info(f"待迁移文档数量: {stats.total_documents}") + + # 逐个处理文档 + batch_count = 0 + for mongo_doc in mongo_collection.find().batch_size(config.batch_size): + try: + stats.processed_count += 1 + doc_id = mongo_doc.get('_id', 'unknown') + + # 构建目标数据 + target_data = {} + for mongo_field, sqlite_field in config.field_mapping.items(): + value = self._get_nested_value(mongo_doc, mongo_field) + + # 获取目标字段对象并转换类型 + if hasattr(config.target_model, sqlite_field): + field_obj = getattr(config.target_model, sqlite_field) + converted_value = self._convert_field_value(value, field_obj) + target_data[sqlite_field] = converted_value + + # 重复检查 + if config.skip_duplicates and self._check_duplicate_by_unique_fields( + config.target_model, target_data, config.unique_fields + ): + stats.duplicate_count += 1 + stats.skipped_count += 1 + logger.debug(f"跳过重复记录: {doc_id}") + continue + + # 使用ORM创建实例 + with db.atomic(): # 每个实例的事务保护 + instance = self._create_model_instance(config.target_model, target_data) + + if instance: + stats.success_count += 1 + else: + stats.add_error(doc_id, "ORM创建实例失败", target_data) + + except Exception as e: + doc_id = mongo_doc.get('_id', 'unknown') + stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) + logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + + # 进度报告 + batch_count += 1 + if batch_count % config.batch_size == 0: + progress = (stats.processed_count / stats.total_documents) * 100 + logger.info( + f"迁移进度: {stats.processed_count}/{stats.total_documents} " + f"({progress:.1f}%) - 成功: {stats.success_count}, " + f"错误: {stats.error_count}, 跳过: {stats.skipped_count}" + ) + + logger.info( + f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" + f"总计: {stats.total_documents}, 成功: {stats.success_count}, " + f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}" + ) + + except Exception as e: + logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") + stats.add_error("collection_error", str(e)) + + return stats + + def migrate_all(self) -> Dict[str, MigrationStats]: + """执行所有迁移任务""" + logger.info("开始执行数据库迁移...") + + if not self.connect_mongodb(): + logger.error("无法连接到MongoDB,迁移终止") + return {} + + all_stats = {} + + try: + for config in self.migration_configs: + logger.info(f"\n开始处理集合: {config.mongo_collection}") + stats = self.migrate_collection(config) + all_stats[config.mongo_collection] = stats + + # 错误率检查 + if stats.processed_count > 0: + error_rate = stats.error_count / stats.processed_count + if error_rate > 0.1: # 错误率超过10% + logger.warning( + f"集合 {config.mongo_collection} 错误率较高: {error_rate:.1%} " + f"({stats.error_count}/{stats.processed_count})" + ) + + finally: + self.disconnect_mongodb() + + self._print_migration_summary(all_stats) + return all_stats + + def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): + """打印迁移汇总信息""" + logger.info("\n" + "="*60) + logger.info("数据迁移汇总报告") + logger.info("="*60) + + total_processed = sum(stats.processed_count for stats in all_stats.values()) + total_success = sum(stats.success_count for stats in all_stats.values()) + total_errors = sum(stats.error_count for stats in all_stats.values()) + total_skipped = sum(stats.skipped_count for stats in all_stats.values()) + total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) + + # 表头 + logger.info(f"{'集合名称':<20} | {'处理':<6} | {'成功':<6} | {'错误':<6} | {'跳过':<6} | {'重复':<6} | {'成功率':<8}") + logger.info("-" * 75) + + for collection_name, stats in all_stats.items(): + success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 + logger.info( + f"{collection_name:<20} | " + f"{stats.processed_count:<6} | " + f"{stats.success_count:<6} | " + f"{stats.error_count:<6} | " + f"{stats.skipped_count:<6} | " + f"{stats.duplicate_count:<6} | " + f"{success_rate:<7.1f}%" + ) + + logger.info("-" * 75) + total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 + logger.info( + f"{'总计':<20} | " + f"{total_processed:<6} | " + f"{total_success:<6} | " + f"{total_errors:<6} | " + f"{total_skipped:<6} | " + f"{total_duplicates:<6} | " + f"{total_success_rate:<7.1f}%" + ) + + if total_errors > 0: + logger.warning(f"\n⚠️ 存在 {total_errors} 个错误,请检查日志详情") + + if total_duplicates > 0: + logger.info(f"ℹ️ 跳过了 {total_duplicates} 个重复记录") + + logger.info("="*60) + + def add_migration_config(self, config: MigrationConfig): + """添加新的迁移配置""" + self.migration_configs.append(config) + + def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]: + """迁移单个指定的集合""" + config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) + if not config: + logger.error(f"未找到集合 {collection_name} 的迁移配置") + return None + + if not self.connect_mongodb(): + logger.error("无法连接到MongoDB") + return None + + try: + stats = self.migrate_collection(config) + self._print_migration_summary({collection_name: stats}) + return stats + finally: + self.disconnect_mongodb() + + def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str): + """导出错误报告""" + error_report = { + 'timestamp': datetime.now().isoformat(), + 'summary': { + collection: { + 'total': stats.total_documents, + 'processed': stats.processed_count, + 'success': stats.success_count, + 'errors': stats.error_count, + 'skipped': stats.skipped_count, + 'duplicates': stats.duplicate_count + } + for collection, stats in all_stats.items() + }, + 'errors': { + collection: stats.errors + for collection, stats in all_stats.items() + if stats.errors + } + } + + try: + with open(filepath, 'w', encoding='utf-8') as f: + json.dump(error_report, f, ensure_ascii=False, indent=2) + logger.info(f"错误报告已导出到: {filepath}") + except Exception as e: + logger.error(f"导出错误报告失败: {e}") + + +def main(): + """主程序入口""" + migrator = MongoToSQLiteMigrator() + + # 执行迁移 + migration_results = migrator.migrate_all() + + # 导出错误报告(如果有错误) + if any(stats.error_count > 0 for stats in migration_results.values()): + error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + migrator.export_error_report(migration_results, error_report_path) + + logger.info("数据迁移完成!") + + +if __name__ == "__main__": + main() \ No newline at end of file From bd58fa48ce000f2c6fbc4aad60e30835a2c737a6 Mon Sep 17 00:00:00 2001 From: A0000Xz <122650088+A0000Xz@users.noreply.github.com> Date: Wed, 28 May 2025 18:00:22 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=81=E7=A7=BB=E5=B7=A5=E5=85=B7=E7=9A=84?= =?UTF-8?q?=E5=90=AF=E5=8A=A8=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mongodb_to_sqlite .bat | 72 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 mongodb_to_sqlite .bat diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat new file mode 100644 index 00000000..bc4c5855 --- /dev/null +++ b/mongodb_to_sqlite .bat @@ -0,0 +1,72 @@ +@echo off +CHCP 65001 > nul +setlocal enabledelayedexpansion + +echo 你需要选择启动方式,输入字母来选择: +echo V = 不知道什么意思就输入 V +echo C = 输入 C 使用 Conda 环境 +echo. +choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V + +set "ENV_TYPE=" +if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" +if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" + +if "%ENV_TYPE%" == "CONDA" goto activate_conda +if "%ENV_TYPE%" == "VENV" goto activate_venv + +REM 如果 choice 超时或返回意外值,默认使用 venv +echo WARN: Invalid selection or timeout from choice. Defaulting to VENV. +set "ENV_TYPE=VENV" +goto activate_venv + +:activate_conda + set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " + if not defined CONDA_ENV_NAME ( + echo 错误: 未输入 Conda 环境名称. + pause + exit /b 1 + ) + echo 选择: Conda '!CONDA_ENV_NAME!' + REM 激活Conda环境 + call conda activate !CONDA_ENV_NAME! + if !ERRORLEVEL! neq 0 ( + echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在. + pause + exit /b 1 + ) + goto env_activated + +:activate_venv + echo Selected: venv (default or selected) + REM 查找venv虚拟环境 + set "venv_path=%~dp0venv\Scripts\activate.bat" + if not exist "%venv_path%" ( + echo Error: venv not found. Ensure the venv directory exists alongside the script. + pause + exit /b 1 + ) + REM 激活虚拟环境 + call "%venv_path%" + if %ERRORLEVEL% neq 0 ( + echo Error: Failed to activate venv virtual environment. + pause + exit /b 1 + ) + goto env_activated + +:env_activated +echo Environment activated successfully! + +REM --- 后续脚本执行 --- + +REM 运行预处理脚本 +python "%~dp0scripts\mongodb_to_sqlite.py" +if %ERRORLEVEL% neq 0 ( + echo Error: raw_data_preprocessor.py execution failed. + pause + exit /b 1 +) + +echo All processing steps completed! +pause \ No newline at end of file From 4ca37440720ad2b9100555d4ae118b4421224f03 Mon Sep 17 00:00:00 2001 From: A0000Xz <122650088+A0000Xz@users.noreply.github.com> Date: Wed, 28 May 2025 18:31:59 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=BB=86=E8=8A=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mongodb_to_sqlite .bat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat index bc4c5855..5c029e44 100644 --- a/mongodb_to_sqlite .bat +++ b/mongodb_to_sqlite .bat @@ -63,7 +63,7 @@ REM --- 后续脚本执行 --- REM 运行预处理脚本 python "%~dp0scripts\mongodb_to_sqlite.py" if %ERRORLEVEL% neq 0 ( - echo Error: raw_data_preprocessor.py execution failed. + echo Error: mongodb_to_sqlite.py execution failed. pause exit /b 1 ) From d84b030fa63013d95c210ad3e820a2ed39159a93 Mon Sep 17 00:00:00 2001 From: A0000Xz <122650088+A0000Xz@users.noreply.github.com> Date: Wed, 28 May 2025 18:32:38 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/mongodb_to_sqlite.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 83e08679..5ff346af 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -258,7 +258,7 @@ class MongoToSQLiteMigrator: "heartflow_data": "heartflow_data_json", "reasoning_data": "reasoning_data_json", }, - unique_fields=["chat_id", "created_at"] + unique_fields=["chat_id", "trigger_text"] ), # 图节点迁移配置 @@ -350,9 +350,6 @@ class MongoToSQLiteMigrator: try: if target_field.name == "record_time" and field_type == "DateTimeField": return datetime.now() - - if target_field.name == "record_time" and field_type == "DateTimeField": - return self._convert_record_time(value) if field_type in ["CharField", "TextField"]: if isinstance(value, (list, dict)): From bc489861d30703fc4062494ef327226c57452485 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 28 May 2025 20:41:46 +0800 Subject: [PATCH 05/13] =?UTF-8?q?feat=EF=BC=9A=E4=B8=BAnormal=5Fchat?= =?UTF-8?q?=E6=8F=90=E4=BE=9B=E9=80=89=E9=A1=B9=EF=BC=8C=E6=9C=89=E6=95=88?= =?UTF-8?q?=E6=8E=A7=E5=88=B6=E5=9B=9E=E5=A4=8D=E9=A2=91=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelogs/changelog.md | 9 +- .../focus_chat/heartflow_message_processor.py | 18 ++- src/chat/heart_flow/background_tasks.py | 74 +--------- src/chat/heart_flow/heartflow.py | 8 +- src/chat/heart_flow/mai_state_manager.py | 135 ------------------ src/chat/heart_flow/sub_heartflow.py | 9 +- src/chat/heart_flow/subheartflow_manager.py | 5 +- src/chat/normal_chat/normal_chat.py | 66 +++++---- src/chat/normal_chat/normal_chat_utils.py | 33 +++++ src/common/logger.py | 4 +- src/config/official_configs.py | 5 +- template/bot_config_template.toml | 56 ++++---- 12 files changed, 138 insertions(+), 284 deletions(-) delete mode 100644 src/chat/heart_flow/mai_state_manager.py create mode 100644 src/chat/normal_chat/normal_chat_utils.py diff --git a/changelogs/changelog.md b/changelogs/changelog.md index 55d87d40..c4dc8b3b 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -8,8 +8,8 @@ - 重构表情包模块 - 移除日程系统 -**重构专注聊天(HFC)** -- 模块化HFC,可以自定义不同的部件 +**重构专注聊天(HFC - focus_chat)** +- 模块化设计,可以自定义不同的部件 - 观察器(获取信息) - 信息处理器(处理信息) - 重构:聊天思考(子心流)处理器 @@ -31,6 +31,10 @@ - 在专注模式下,麦麦可以决定自行发送语音消息(需要搭配tts适配器) - 优化reply,减少复读 +**优化普通聊天(normal_chat)** +- 增加了talk_frequency参数来有效控制回复频率 +- 优化了进入和离开normal_chat的方式 + **新增表达方式学习** - 在专注模式下,麦麦可以有独特的表达方式 - 自主学习群聊中的表达方式,更贴近群友 @@ -50,6 +54,7 @@ **人格** - 简化了人格身份的配置 +- 优化了在focus模式下人格的表现和稳定性 **数据库重构** - 移除了默认使用MongoDB,采用轻量sqlite diff --git a/src/chat/focus_chat/heartflow_message_processor.py b/src/chat/focus_chat/heartflow_message_processor.py index c1efeb52..18349d7a 100644 --- a/src/chat/focus_chat/heartflow_message_processor.py +++ b/src/chat/focus_chat/heartflow_message_processor.py @@ -8,6 +8,7 @@ from ..utils.utils import is_mentioned_bot_in_message from src.chat.heart_flow.heartflow import heartflow from src.common.logger_manager import get_logger from ..message_receive.chat_stream import chat_manager +import math # from ..message_receive.message_buffer import message_buffer from ..utils.timer_calculator import Timer @@ -69,6 +70,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: message.processed_plain_text, fast_retrieval=True, ) + text_len = len(message.processed_plain_text) + # 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05 + # 采用对数函数实现递减增长 + + base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1)) + base_interest = min(max(base_interest, 0.01), 0.05) + + interested_rate += base_interest + logger.trace(f"记忆激活率: {interested_rate:.2f}") if is_mentioned: @@ -205,17 +215,15 @@ class HeartFCMessageReceiver: # 6. 兴趣度计算与更新 interested_rate, is_mentioned = await _calculate_interest(message) - # await subheartflow.interest_chatting.increase_interest(value=interested_rate) - subheartflow.add_interest_message(message, interested_rate, is_mentioned) + subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # 7. 日志记录 mes_name = chat.group_info.group_name if chat.group_info else "私聊" - current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) + # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) logger.info( - f"[{current_time}][{mes_name}]" + f"[{mes_name}]" f"{userinfo.user_nickname}:" f"{message.processed_plain_text}" - f"[激活: {interested_rate:.1f}]" ) # 8. 关系处理 diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py index 9479804e..db260456 100644 --- a/src/chat/heart_flow/background_tasks.py +++ b/src/chat/heart_flow/background_tasks.py @@ -2,25 +2,18 @@ import asyncio import traceback from typing import Optional, Coroutine, Callable, Any, List from src.common.logger_manager import get_logger -from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager from src.config.config import global_config logger = get_logger("background_tasks") -# 新增兴趣评估间隔 -INTEREST_EVAL_INTERVAL_SECONDS = 5 -# 新增聊天超时检查间隔 -NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS = 60 -# 新增状态评估间隔 -HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS = 20 + + # 新增私聊激活检查间隔 PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS = 5 # 与兴趣评估类似,设为5秒 CLEANUP_INTERVAL_SECONDS = 1200 -STATE_UPDATE_INTERVAL_SECONDS = 60 -LOG_INTERVAL_SECONDS = 3 async def _run_periodic_loop( @@ -55,19 +48,13 @@ class BackgroundTaskManager: def __init__( self, - mai_state_info: MaiStateInfo, # Needs current state info - mai_state_manager: MaiStateManager, subheartflow_manager: SubHeartflowManager, ): - self.mai_state_info = mai_state_info - self.mai_state_manager = mai_state_manager self.subheartflow_manager = subheartflow_manager # Task references - self._state_update_task: Optional[asyncio.Task] = None self._cleanup_task: Optional[asyncio.Task] = None self._hf_judge_state_update_task: Optional[asyncio.Task] = None - self._into_focus_task: Optional[asyncio.Task] = None self._private_chat_activation_task: Optional[asyncio.Task] = None # 新增私聊激活任务引用 self._tasks: List[Optional[asyncio.Task]] = [] # Keep track of all tasks @@ -80,15 +67,7 @@ class BackgroundTaskManager: - 将任务引用保存到任务列表 """ - # 任务配置列表: (任务函数, 任务名称, 日志级别, 额外日志信息, 任务对象引用属性名) - task_configs = [ - ( - lambda: self._run_state_update_cycle(STATE_UPDATE_INTERVAL_SECONDS), - "debug", - f"聊天状态更新任务已启动 间隔:{STATE_UPDATE_INTERVAL_SECONDS}s", - "_state_update_task", - ), - ] + task_configs = [] # 根据 chat_mode 条件添加其他任务 if not (global_config.chat.chat_mode == "normal"): @@ -108,12 +87,6 @@ class BackgroundTaskManager: f"私聊激活检查任务已启动 间隔:{PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS}s", "_private_chat_activation_task", ), - # ( - # self._run_into_focus_cycle, - # "debug", # 设为debug,避免过多日志 - # f"专注评估任务已启动 间隔:{INTEREST_EVAL_INTERVAL_SECONDS}s", - # "_into_focus_task", - # ) ] ) else: @@ -163,33 +136,9 @@ class BackgroundTaskManager: # 第三步:清空任务列表 self._tasks = [] # 重置任务列表 - async def _perform_state_update_work(self): - """执行状态更新工作""" - previous_status = self.mai_state_info.get_current_state() - next_state = self.mai_state_manager.check_and_decide_next_state(self.mai_state_info) - - state_changed = False - - if next_state is not None: - state_changed = self.mai_state_info.update_mai_status(next_state) - - # 处理保持离线状态的特殊情况 - if not state_changed and next_state == previous_status == self.mai_state_info.mai_status.OFFLINE: - self.mai_state_info.reset_state_timer() - logger.debug("[后台任务] 保持离线状态并重置计时器") - state_changed = True # 触发后续处理 - - if state_changed: - current_state = self.mai_state_info.get_current_state() - # 状态转换处理 - if ( - current_state == self.mai_state_info.mai_status.OFFLINE - and previous_status != self.mai_state_info.mai_status.OFFLINE - ): - logger.info("检测到离线,停用所有子心流") - await self.subheartflow_manager.deactivate_all_subflows() + async def _perform_cleanup_work(self): """执行子心流清理任务 @@ -216,27 +165,12 @@ class BackgroundTaskManager: # 记录最终清理结果 logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流") - # --- 新增兴趣评估工作函数 --- - # async def _perform_into_focus_work(self): - # """执行一轮子心流兴趣评估与提升检查。""" - # # 直接调用 subheartflow_manager 的方法,并传递当前状态信息 - # await self.subheartflow_manager.sbhf_normal_into_focus() - - async def _run_state_update_cycle(self, interval: int): - await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work) async def _run_cleanup_cycle(self): await _run_periodic_loop( task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work ) - # --- 新增兴趣评估任务运行器 --- - # async def _run_into_focus_cycle(self): - # await _run_periodic_loop( - # task_name="Into Focus", - # interval=INTEREST_EVAL_INTERVAL_SECONDS, - # task_func=self._perform_into_focus_work, - # ) # 新增私聊激活任务运行器 async def _run_private_chat_activation_cycle(self, interval: int): diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index e1f8d957..7e12135a 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -1,7 +1,6 @@ from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.common.logger_manager import get_logger from typing import Any, Optional, List -from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager @@ -16,17 +15,12 @@ class Heartflow: """ def __init__(self): - # 状态管理相关 - self.current_state: MaiStateInfo = MaiStateInfo() # 当前状态信息 - self.mai_state_manager: MaiStateManager = MaiStateManager() # 状态决策管理器 # 子心流管理 (在初始化时传入 current_state) - self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state) + self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager() # 后台任务管理器 (整合所有定时任务) self.background_task_manager: BackgroundTaskManager = BackgroundTaskManager( - mai_state_info=self.current_state, - mai_state_manager=self.mai_state_manager, subheartflow_manager=self.subheartflow_manager, ) diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py deleted file mode 100644 index 81f03dec..00000000 --- a/src/chat/heart_flow/mai_state_manager.py +++ /dev/null @@ -1,135 +0,0 @@ -import enum -import time -import random -from typing import List, Tuple, Optional -from src.common.logger_manager import get_logger -from src.manager.mood_manager import mood_manager - -logger = get_logger("mai_state") - - -class MaiState(enum.Enum): - """ - 聊天状态: - OFFLINE: 不在线:回复概率极低,不会进行任何聊天 - NORMAL_CHAT: 正常看手机:回复概率较高,会进行一些普通聊天和少量的专注聊天 - FOCUSED_CHAT: 专注聊天:回复概率极高,会进行专注聊天和少量的普通聊天 - """ - - OFFLINE = "不在线" - NORMAL_CHAT = "正常看手机" - FOCUSED_CHAT = "专心看手机" - - -class MaiStateInfo: - def __init__(self): - self.mai_status: MaiState = MaiState.NORMAL_CHAT # 初始状态改为 NORMAL_CHAT - self.mai_status_history: List[Tuple[MaiState, float]] = [] # 历史状态,包含 状态,时间戳 - self.last_status_change_time: float = time.time() # 状态最后改变时间 - self.last_min_check_time: float = time.time() # 上次1分钟规则检查时间 - - # Mood management is now part of MaiStateInfo - self.mood_manager = mood_manager # Use singleton instance - - def update_mai_status(self, new_status: MaiState) -> bool: - """ - 更新聊天状态。 - - Args: - new_status: 新的 MaiState 状态。 - - Returns: - bool: 如果状态实际发生了改变则返回 True,否则返回 False。 - """ - if new_status != self.mai_status: - self.mai_status = new_status - current_time = time.time() - self.last_status_change_time = current_time - self.last_min_check_time = current_time # Reset 1-min check on any state change - self.mai_status_history.append((new_status, current_time)) - logger.info(f"麦麦状态更新为: {self.mai_status.value}") - return True - else: - return False - - def reset_state_timer(self): - """ - 重置状态持续时间计时器和一分钟规则检查计时器。 - 通常在状态保持不变但需要重新开始计时的情况下调用(例如,保持 OFFLINE)。 - """ - current_time = time.time() - self.last_status_change_time = current_time - self.last_min_check_time = current_time # Also reset the 1-min check timer - logger.debug("MaiStateInfo 状态计时器已重置。") - - def get_mood_prompt(self) -> str: - """获取当前的心情提示词""" - # Delegate to the internal mood manager - return self.mood_manager.get_mood_prompt() - - def get_current_state(self) -> MaiState: - """获取当前的 MaiState""" - return self.mai_status - - -class MaiStateManager: - """管理 Mai 的整体状态转换逻辑""" - - def __init__(self): - pass - - @staticmethod - def check_and_decide_next_state(current_state_info: MaiStateInfo) -> Optional[MaiState]: - """ - 根据当前状态和规则检查是否需要转换状态,并决定下一个状态。 - """ - current_time = time.time() - current_status = current_state_info.mai_status - time_in_current_status = current_time - current_state_info.last_status_change_time - next_state: Optional[MaiState] = None - - def _resolve_offline(candidate_state: MaiState) -> MaiState: - if candidate_state == MaiState.OFFLINE: - return current_status - return candidate_state - - if current_status == MaiState.OFFLINE: - logger.info("当前[离线],没看手机,思考要不要上线看看......") - elif current_status == MaiState.NORMAL_CHAT: - logger.info("当前在[正常看手机]思考要不要继续聊下去......") - elif current_status == MaiState.FOCUSED_CHAT: - logger.info("当前在[专心看手机]思考要不要继续聊下去......") - - if next_state is None: - time_limit_exceeded = False - choices_list = [] - weights = [] - rule_id = "" - - if current_status == MaiState.OFFLINE: - return None - elif current_status == MaiState.NORMAL_CHAT: - if time_in_current_status >= 300: # NORMAL_CHAT 最多持续 300 秒 - time_limit_exceeded = True - rule_id = "2.3 (From NORMAL_CHAT)" - weights = [100] - choices_list = [MaiState.FOCUSED_CHAT] - elif current_status == MaiState.FOCUSED_CHAT: - if time_in_current_status >= 600: # FOCUSED_CHAT 最多持续 600 秒 - time_limit_exceeded = True - rule_id = "2.4 (From FOCUSED_CHAT)" - weights = [100] - choices_list = [MaiState.NORMAL_CHAT] - - if time_limit_exceeded: - next_state_candidate = random.choices(choices_list, weights=weights, k=1)[0] - resolved_candidate = _resolve_offline(next_state_candidate) - logger.debug( - f"规则{rule_id}:时间到,切换到 {next_state_candidate.value},resolve 为 {resolved_candidate.value}" - ) - next_state = resolved_candidate - - if next_state is not None and next_state != current_status: - return next_state - else: - return None diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 0d4ac281..b7267bdc 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -9,7 +9,6 @@ from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import chat_manager from src.chat.focus_chat.heartFC_chat import HeartFChatting from src.chat.normal_chat.normal_chat import NormalChat -from src.chat.heart_flow.mai_state_manager import MaiStateInfo from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo from .utils_chat import get_chat_type_and_target_info from src.config.config import global_config @@ -23,7 +22,6 @@ class SubHeartflow: def __init__( self, subheartflow_id, - mai_states: MaiStateInfo, ): """子心流初始化函数 @@ -36,9 +34,6 @@ class SubHeartflow: self.subheartflow_id = subheartflow_id self.chat_id = subheartflow_id - # 麦麦的状态 - self.mai_states = mai_states - # 这个聊天流的状态 self.chat_state: ChatStateInfo = ChatStateInfo() self.chat_state_changed_time: float = time.time() @@ -334,10 +329,10 @@ class SubHeartflow: return self.normal_chat_instance.get_recent_replies(limit) return [] - def add_interest_message(self, message: MessageRecv, interest_value: float, is_mentioned: bool): + def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool): self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # 如果字典长度超过10,删除最旧的消息 - if len(self.interest_dict) > 10: + if len(self.interest_dict) > 30: oldest_key = next(iter(self.interest_dict)) self.interest_dict.pop(oldest_key) diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py index e3e68f55..bad4393c 100644 --- a/src/chat/heart_flow/subheartflow_manager.py +++ b/src/chat/heart_flow/subheartflow_manager.py @@ -4,7 +4,6 @@ from typing import Dict, Any, Optional, List from src.common.logger_manager import get_logger from src.chat.message_receive.chat_stream import chat_manager from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState -from src.chat.heart_flow.mai_state_manager import MaiStateInfo from src.chat.heart_flow.observation.chatting_observation import ChattingObservation @@ -54,10 +53,9 @@ async def _try_set_subflow_absent_internal(subflow: "SubHeartflow", log_prefix: class SubHeartflowManager: """管理所有活跃的 SubHeartflow 实例。""" - def __init__(self, mai_state_info: MaiStateInfo): + def __init__(self): self.subheartflows: Dict[Any, "SubHeartflow"] = {} self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问 - self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例 async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool: """强制改变指定子心流的状态""" @@ -97,7 +95,6 @@ class SubHeartflowManager: # 初始化子心流, 传入 mai_state_info new_subflow = SubHeartflow( subheartflow_id, - self.mai_state_info, ) # 首先创建并添加聊天观察者 diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 67bc1ab4..cb78c701 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -1,12 +1,9 @@ import asyncio -import statistics # 导入 statistics 模块 import time import traceback from random import random from typing import List, Optional # 导入 Optional - from maim_message import UserInfo, Seg - from src.common.logger_manager import get_logger from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.manager.mood_manager import mood_manager @@ -21,6 +18,7 @@ from src.chat.message_receive.message_sender import message_manager from src.chat.utils.utils_image import image_path_to_base64 from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.normal_chat.willing.willing_manager import willing_manager +from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats from src.config.config import global_config logger = get_logger("normal_chat") @@ -39,6 +37,8 @@ class NormalChat: self.is_group_chat: bool = False self.chat_target_info: Optional[dict] = None + + self.willing_amplifier = 1 # Other sync initializations self.gpt = NormalChatGenerator() @@ -209,10 +209,12 @@ class NormalChat: for msg_id, (message, interest_value, is_mentioned) in items_to_process: try: # 处理消息 + self.adjust_reply_frequency() + await self.normal_response( message=message, is_mentioned=is_mentioned, - interested_rate=interest_value, + interested_rate=interest_value * self.willing_amplifier, rewind_response=False, ) except Exception as e: @@ -228,26 +230,18 @@ class NormalChat: if self._disabled: logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。") return - # 检查收到的消息是否属于当前实例处理的 chat stream - if message.chat_stream.stream_id != self.stream_id: - logger.error( - f"[{self.stream_name}] normal_response 收到不匹配的消息 (来自 {message.chat_stream.stream_id}),预期 {self.stream_id}。已忽略。" - ) - return timing_results = {} - reply_probability = 1.0 if is_mentioned else 0.0 # 如果被提及,基础概率为1,否则需要意愿判断 - + # 意愿管理器:设置当前message信息 - willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate) # 获取回复概率 - is_willing = False + # is_willing = False # 仅在未被提及或基础概率不为1时查询意愿概率 if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率 - is_willing = True + # is_willing = True reply_probability = await willing_manager.get_reply_probability(message.message_info.message_id) if message.message_info.additional_config: @@ -257,13 +251,13 @@ class NormalChat: # 打印消息信息 mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊" - current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) + # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) # 使用 self.stream_id - willing_log = f"[回复意愿:{await willing_manager.get_willing(self.stream_id):.2f}]" if is_willing else "" + # willing_log = f"[激活值:{await willing_manager.get_willing(self.stream_id):.2f}]" if is_willing else "" logger.info( - f"[{current_time}][{mes_name}]" + f"[{mes_name}]" f"{message.message_info.user_info.user_nickname}:" # 使用 self.chat_stream - f"{message.processed_plain_text}{willing_log}[概率:{reply_probability * 100:.1f}%]" + f"{message.processed_plain_text}[回复概率:{reply_probability * 100:.1f}%]" ) do_reply = False response_set = None # 初始化 response_set @@ -346,16 +340,13 @@ class NormalChat: # 检查是否需要切换到focus模式 await self._check_switch_to_focus() - else: - logger.warning(f"[{self.stream_name}] 思考消息 {thinking_id} 在发送前丢失,无法记录 info_catcher") info_catcher.done_catch() - # 处理表情包 (不再需要传入 chat) with Timer("处理表情包", timing_results): await self._handle_emoji(message, response_set[0]) - # 更新关系情绪 (不再需要传入 chat) + with Timer("关系更新", timing_results): await self._update_relationship(message, response_set) @@ -479,9 +470,6 @@ class NormalChat: # 统计1分钟内的回复数量 recent_reply_count = sum(1 for reply in self.recent_replies if reply["time"] > one_minute_ago) - # print(111111111111111333333333333333333333333331111111111111111111111111111111111) - # print(recent_reply_count) - # 如果1分钟内回复数量大于8,触发切换到focus模式 if recent_reply_count > reply_threshold: logger.info( f"[{self.stream_name}] 检测到1分钟内回复数量({recent_reply_count})大于{reply_threshold},触发切换到focus模式" @@ -491,3 +479,29 @@ class NormalChat: await self.on_switch_to_focus_callback() except Exception as e: logger.error(f"[{self.stream_name}] 触发切换到focus模式时出错: {e}\n{traceback.format_exc()}") + + + def adjust_reply_frequency(self, duration: int = 10): + """ + 调整回复频率 + """ + # 获取最近30分钟内的消息统计 + print(f"willing_amplifier: {self.willing_amplifier}") + stats = get_recent_message_stats(minutes=duration, chat_id=self.stream_id) + bot_reply_count = stats["bot_reply_count"] + print(f"[{self.stream_name}] 最近{duration}分钟内回复数量: {bot_reply_count}") + total_message_count = stats["total_message_count"] + print(f"[{self.stream_name}] 最近{duration}分钟内消息总数: {total_message_count}") + + # 计算回复频率 + _reply_frequency = bot_reply_count / total_message_count + + # 如果回复频率低于0.5,增加回复概率 + if bot_reply_count/duration < global_config.normal_chat.talk_frequency: + # differ = global_config.normal_chat.talk_frequency - reply_frequency + logger.info(f"[{self.stream_name}] 回复频率低于{global_config.normal_chat.talk_frequency},增加回复概率") + self.willing_amplifier += 0.1 + else: + logger.info(f"[{self.stream_name}] 回复频率高于{global_config.normal_chat.talk_frequency},减少回复概率") + self.willing_amplifier -= 0.1 + diff --git a/src/chat/normal_chat/normal_chat_utils.py b/src/chat/normal_chat/normal_chat_utils.py new file mode 100644 index 00000000..2eae4829 --- /dev/null +++ b/src/chat/normal_chat/normal_chat_utils.py @@ -0,0 +1,33 @@ +import time +from src.config.config import global_config +from src.common.message_repository import count_messages + + +def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict: + """ + Args: + minutes (int): 检索的分钟数,默认30分钟 + chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。 + Returns: + dict: {"bot_reply_count": int, "total_message_count": int} + """ + + now = time.time() + start_time = now - minutes * 60 + bot_id = global_config.bot.qq_account + + filter_base = {"time": {"$gte": start_time}} + if chat_id is not None: + filter_base["chat_id"] = chat_id + + # 总消息数 + total_message_count = count_messages(filter_base) + # bot自身回复数 + bot_filter = filter_base.copy() + bot_filter["user_id"] = bot_id + bot_reply_count = count_messages(bot_filter) + + return { + "bot_reply_count": bot_reply_count, + "total_message_count": total_message_count + } \ No newline at end of file diff --git a/src/common/logger.py b/src/common/logger.py index 3ed0fd7f..7258d619 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -271,7 +271,7 @@ CHAT_STYLE_CONFIG = { "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 见闻 | {message}", # noqa: E501 + "console_format": "{time:HH:mm:ss} | 见闻 | {message}", # noqa: E501 "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}", }, } @@ -288,7 +288,7 @@ NORMAL_CHAT_STYLE_CONFIG = { "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501 + "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501 "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}", }, } diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 5e5199da..e353500f 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -91,7 +91,7 @@ class NormalChatConfig(ConfigBase): max_context_size: int = 15 """上下文长度""" - message_buffer: bool = True + message_buffer: bool = False """消息缓冲器""" emoji_chance: float = 0.2 @@ -103,6 +103,9 @@ class NormalChatConfig(ConfigBase): willing_mode: str = "classical" """意愿模式""" + talk_frequency: float = 1 + """回复频率阈值""" + response_willing_amplifier: float = 1.0 """回复意愿放大系数""" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index c7863340..946e3374 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "2.5.0" +version = "2.6.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -18,22 +18,24 @@ nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #仅在 专注聊天 有效 [personality] -personality_core = "是一个积极向上的女大学生" # 建议20字以内,谁再写3000字小作文敲谁脑袋 +personality_core = "是一个积极向上的女大学生" # 建议50字以内 personality_sides = [ "用一句话或几句话描述人格的一些细节", "用一句话或几句话描述人格的一些细节", "用一句话或几句话描述人格的一些细节", -]# 条数任意,不能为0 +] +# 条数任意,不能为0 # 身份特点 -[identity] #アイデンティティがない 生まれないらららら +#アイデンティティがない 生まれないらららら +[identity] identity_detail = [ "年龄为19岁", "是女孩子", "身高为160cm", "有橙色的短发", ] -# 可以描述外贸,性别,身高,职业,属性等等描述 +# 可以描述外貌,性别,身高,职业,属性等等描述 # 条数任意,不能为0 [relationship] @@ -69,15 +71,18 @@ normal_chat_first_probability = 0.3 # 麦麦回答时选择首要模型的概率 max_context_size = 15 #上下文长度 emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发 thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) -message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现) +talk_frequency = 1 # 麦麦回复频率,一般为1,默认频率下,30分钟麦麦回复30条(约数) + response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1 response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数 -down_frequency_rate = 3 # 降低回复频率的群组回复意愿降低系数 除法 + emoji_response_penalty = 0 # 表情包回复惩罚系数,设为0为不回复单个表情包,减少单独回复表情包的概率 mentioned_bot_inevitable_reply = false # 提及 bot 必然回复 at_bot_inevitable_reply = false # @bot 必然回复 + +down_frequency_rate = 3 # 降低回复频率的群组回复意愿降低系数 除法 talk_frequency_down_groups = [] #降低回复频率的群号码 [focus_chat] #专注聊天 @@ -163,26 +168,8 @@ max_length = 256 # 回复允许的最大长度 max_sentence_num = 4 # 回复允许的最大句子数 enable_kaomoji_protection = false # 是否启用颜文字保护 -[maim_message] -auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 -# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 -use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复 -host="127.0.0.1" -port=8090 -mode="ws" # 支持ws和tcp两种模式 -use_wss = false # 是否使用WSS安全连接,只支持ws模式 -cert_file = "" # SSL证书文件路径,仅在use_wss=true时有效 -key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 -[telemetry] #发送统计信息,主要是看全球有多少只麦麦 -enable = true - -[experimental] #实验性功能 -debug_show_chat_mode = true # 是否在回复后显示当前聊天模式 -enable_friend_chat = false # 是否启用好友聊天 -pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立,在0.7.0暂时无效 - #下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写 # stream = : 用于指定模型是否是使用流式输出 @@ -333,5 +320,24 @@ pri_out = 8 +[maim_message] +auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 +# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 +use_custom = false # 是否启用自定义的maim_message服务器,注意这需要设置新的端口,不能与.env重复 +host="127.0.0.1" +port=8090 +mode="ws" # 支持ws和tcp两种模式 +use_wss = false # 是否使用WSS安全连接,只支持ws模式 +cert_file = "" # SSL证书文件路径,仅在use_wss=true时有效 +key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 + +[telemetry] #发送统计信息,主要是看全球有多少只麦麦 +enable = true + +[experimental] #实验性功能 +debug_show_chat_mode = true # 是否在回复后显示当前聊天模式 +enable_friend_chat = false # 是否启用好友聊天 +pfc_chatting = false # 暂时无效 + From 218d0d4a5dfe1948e3d494a83b413b68f4ba0a1b Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 28 May 2025 20:44:26 +0800 Subject: [PATCH 06/13] pass ruff --- .../expressors/default_expressor.py | 19 ++++++++--------- src/chat/focus_chat/heartFC_chat.py | 4 +++- .../focus_chat/heartflow_message_processor.py | 13 ++++-------- src/chat/heart_flow/background_tasks.py | 8 +------ src/chat/heart_flow/heartflow.py | 1 - src/chat/heart_flow/sub_heartflow.py | 1 + src/chat/message_receive/storage.py | 2 +- src/chat/normal_chat/normal_chat.py | 21 +++++++------------ src/chat/normal_chat/normal_chat_generator.py | 10 ++++----- src/chat/normal_chat/normal_chat_utils.py | 5 +---- src/chat/normal_chat/normal_prompt.py | 1 - src/chat/utils/chat_message_builder.py | 5 +++-- src/config/official_configs.py | 6 ++---- src/plugins/test_plugin/actions/__init__.py | 1 + .../test_plugin/actions/mute_action.py | 9 ++++---- template/bot_config_template.toml | 2 +- 16 files changed, 45 insertions(+), 63 deletions(-) diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index cbea0a22..382d20aa 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -16,7 +16,6 @@ from src.chat.utils.info_catcher import info_catcher_manager from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp -from src.individuality.individuality import individuality from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat import time @@ -106,10 +105,7 @@ class DefaultExpressor: user_nickname=global_config.bot.nickname, platform=messageinfo.platform, ) - # logger.debug(f"创建思考消息:{anchor_message}") - # logger.debug(f"创建思考消息chat:{chat}") - # logger.debug(f"创建思考消息bot_user_info:{bot_user_info}") - # logger.debug(f"创建思考消息messageinfo:{messageinfo}") + thinking_message = MessageThinking( message_id=thinking_id, chat_stream=chat, @@ -281,14 +277,14 @@ class DefaultExpressor: in_mind_reply, target_message, ) -> str: - prompt_personality = individuality.get_prompt(x_person=0, level=2) + # prompt_personality = individuality.get_prompt(x_person=0, level=2) # Determine if it's a group chat is_group_chat = bool(chat_stream.group_info) # Use sender_name passed from caller for private chat, otherwise use a default for group # Default sender_name for group chat isn't used in the group prompt template, but set for consistency - effective_sender_name = sender_name if not is_group_chat else "某人" + # effective_sender_name = sender_name if not is_group_chat else "某人" message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, @@ -377,7 +373,11 @@ class DefaultExpressor: # --- 发送器 (Sender) --- # async def send_response_messages( - self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = "", display_message: str = "" + self, + anchor_message: Optional[MessageRecv], + response_set: List[Tuple[str, str]], + thinking_id: str = "", + display_message: str = "", ) -> Optional[MessageSending]: """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream @@ -412,10 +412,9 @@ class DefaultExpressor: # 为每个消息片段生成唯一ID type = msg_text[0] data = msg_text[1] - + if global_config.experimental.debug_show_chat_mode and type == "text": data += "ᶠ" - part_message_id = f"{thinking_id}_{i}" message_segment = Seg(type=type, data=data) diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index ab061d85..63ce5c30 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -379,12 +379,14 @@ class HeartFChatting: for processor in self.processors: processor_name = processor.__class__.log_prefix + # 用lambda包裹,便于传参 async def run_with_timeout(proc=processor): return await asyncio.wait_for( proc.process_info(observations=observations, running_memorys=running_memorys), - timeout=PROCESSOR_TIMEOUT + timeout=PROCESSOR_TIMEOUT, ) + task = asyncio.create_task(run_with_timeout()) processor_tasks.append(task) task_to_name_map[task] = processor_name diff --git a/src/chat/focus_chat/heartflow_message_processor.py b/src/chat/focus_chat/heartflow_message_processor.py index 18349d7a..8277e69f 100644 --- a/src/chat/focus_chat/heartflow_message_processor.py +++ b/src/chat/focus_chat/heartflow_message_processor.py @@ -1,4 +1,3 @@ -import time import traceback from ..memory_system.Hippocampus import HippocampusManager from ...config.config import global_config @@ -73,12 +72,12 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: text_len = len(message.processed_plain_text) # 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05 # 采用对数函数实现递减增长 - + base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1)) base_interest = min(max(base_interest, 0.01), 0.05) - + interested_rate += base_interest - + logger.trace(f"记忆激活率: {interested_rate:.2f}") if is_mentioned: @@ -220,11 +219,7 @@ class HeartFCMessageReceiver: # 7. 日志记录 mes_name = chat.group_info.group_name if chat.group_info else "私聊" # current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time)) - logger.info( - f"[{mes_name}]" - f"{userinfo.user_nickname}:" - f"{message.processed_plain_text}" - ) + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}") # 8. 关系处理 if global_config.relationship.give_name: diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py index db260456..4bacfd0a 100644 --- a/src/chat/heart_flow/background_tasks.py +++ b/src/chat/heart_flow/background_tasks.py @@ -8,8 +8,6 @@ from src.config.config import global_config logger = get_logger("background_tasks") - - # 新增私聊激活检查间隔 PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS = 5 # 与兴趣评估类似,设为5秒 @@ -136,9 +134,7 @@ class BackgroundTaskManager: # 第三步:清空任务列表 self._tasks = [] # 重置任务列表 - # 状态转换处理 - - + # 状态转换处理 async def _perform_cleanup_work(self): """执行子心流清理任务 @@ -165,13 +161,11 @@ class BackgroundTaskManager: # 记录最终清理结果 logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流") - async def _run_cleanup_cycle(self): await _run_periodic_loop( task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work ) - # 新增私聊激活任务运行器 async def _run_private_chat_activation_cycle(self, interval: int): await _run_periodic_loop( diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index 7e12135a..d58c5cde 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -15,7 +15,6 @@ class Heartflow: """ def __init__(self): - # 子心流管理 (在初始化时传入 current_state) self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager() diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index b7267bdc..c95f606b 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -18,6 +18,7 @@ logger = get_logger("sub_heartflow") install(extra_lines=3) + class SubHeartflow: def __init__( self, diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 54dd68f0..8c05a9ab 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -24,7 +24,7 @@ class MessageStorage: else: filtered_processed_plain_text = "" - if isinstance(message,MessageSending): + if isinstance(message, MessageSending): display_message = message.display_message if display_message: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index cb78c701..98478183 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -37,7 +37,7 @@ class NormalChat: self.is_group_chat: bool = False self.chat_target_info: Optional[dict] = None - + self.willing_amplifier = 1 # Other sync initializations @@ -56,7 +56,6 @@ class NormalChat: self._disabled = False # 增加停用标志 - async def initialize(self): """异步初始化,获取聊天类型和目标信息。""" if self._initialized: @@ -210,7 +209,7 @@ class NormalChat: try: # 处理消息 self.adjust_reply_frequency() - + await self.normal_response( message=message, is_mentioned=is_mentioned, @@ -233,7 +232,7 @@ class NormalChat: timing_results = {} reply_probability = 1.0 if is_mentioned else 0.0 # 如果被提及,基础概率为1,否则需要意愿判断 - + # 意愿管理器:设置当前message信息 willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate) @@ -306,7 +305,7 @@ class NormalChat: return # 不执行后续步骤 logger.info(f"[{self.stream_name}] 回复内容: {response_set}") - + if self._disabled: logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。") return @@ -340,13 +339,11 @@ class NormalChat: # 检查是否需要切换到focus模式 await self._check_switch_to_focus() - info_catcher.done_catch() with Timer("处理表情包", timing_results): await self._handle_emoji(message, response_set[0]) - with Timer("关系更新", timing_results): await self._update_relationship(message, response_set) @@ -479,8 +476,7 @@ class NormalChat: await self.on_switch_to_focus_callback() except Exception as e: logger.error(f"[{self.stream_name}] 触发切换到focus模式时出错: {e}\n{traceback.format_exc()}") - - + def adjust_reply_frequency(self, duration: int = 10): """ 调整回复频率 @@ -492,16 +488,15 @@ class NormalChat: print(f"[{self.stream_name}] 最近{duration}分钟内回复数量: {bot_reply_count}") total_message_count = stats["total_message_count"] print(f"[{self.stream_name}] 最近{duration}分钟内消息总数: {total_message_count}") - + # 计算回复频率 _reply_frequency = bot_reply_count / total_message_count - + # 如果回复频率低于0.5,增加回复概率 - if bot_reply_count/duration < global_config.normal_chat.talk_frequency: + if bot_reply_count / duration < global_config.normal_chat.talk_frequency: # differ = global_config.normal_chat.talk_frequency - reply_frequency logger.info(f"[{self.stream_name}] 回复频率低于{global_config.normal_chat.talk_frequency},增加回复概率") self.willing_amplifier += 0.1 else: logger.info(f"[{self.stream_name}] 回复频率高于{global_config.normal_chat.talk_frequency},减少回复概率") self.willing_amplifier -= 0.1 - diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py index e0a88b4e..2ad1a197 100644 --- a/src/chat/normal_chat/normal_chat_generator.py +++ b/src/chat/normal_chat/normal_chat_generator.py @@ -64,11 +64,12 @@ class NormalChatGenerator: async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str): info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + person_id = person_info_manager.get_person_id( + message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id + ) - person_id = person_info_manager.get_person_id(message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id) - person_name = await person_info_manager.get_value(person_id, "person_name") - + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( f"[{message.chat_stream.user_info.user_nickname}]" @@ -78,8 +79,7 @@ class NormalChatGenerator: sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" else: sender_name = f"用户({message.chat_stream.user_info.user_id})" - - + # 构建prompt with Timer() as t_build_prompt: prompt = await prompt_builder.build_prompt( diff --git a/src/chat/normal_chat/normal_chat_utils.py b/src/chat/normal_chat/normal_chat_utils.py index 2eae4829..2ebd3bda 100644 --- a/src/chat/normal_chat/normal_chat_utils.py +++ b/src/chat/normal_chat/normal_chat_utils.py @@ -27,7 +27,4 @@ def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict: bot_filter["user_id"] = bot_id bot_reply_count = count_messages(bot_filter) - return { - "bot_reply_count": bot_reply_count, - "total_message_count": total_message_count - } \ No newline at end of file + return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count} diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py index 365d42de..54624a02 100644 --- a/src/chat/normal_chat/normal_prompt.py +++ b/src/chat/normal_chat/normal_prompt.py @@ -17,7 +17,6 @@ logger = get_logger("prompt") def init_prompt(): - Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1") Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") Prompt("在群里聊天", "chat_target_group2") diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 38c9e9a8..66b58e25 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -10,6 +10,7 @@ from rich.traceback import install install(extra_lines=3) + def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: @@ -198,7 +199,7 @@ async def _build_readable_messages_internal( content = msg.get("display_message") else: content = msg.get("processed_plain_text", "") # 默认空字符串 - + if "ᶠ" in content: content = content.replace("ᶠ", "") if "ⁿ" in content: @@ -465,7 +466,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: content = msg.get("display_message") else: content = msg.get("processed_plain_text", "") - + if "ᶠ" in content: content = content.replace("ᶠ", "") if "ⁿ" in content: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e353500f..8f98256e 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -59,7 +59,7 @@ class ChatConfig(ConfigBase): chat_mode: str = "normal" """聊天模式""" - + auto_focus_threshold: float = 1.0 """自动切换到专注聊天的阈值,越低越容易进入专注聊天""" @@ -132,8 +132,6 @@ class NormalChatConfig(ConfigBase): class FocusChatConfig(ConfigBase): """专注聊天配置类""" - - observation_context_size: int = 12 """可观察到的最长上下文大小,超过这个值的上下文会被压缩""" @@ -346,7 +344,7 @@ class TelemetryConfig(ConfigBase): class ExperimentalConfig(ConfigBase): """实验功能配置类""" - debug_show_chat_mode: bool = True + debug_show_chat_mode: bool = False """是否在回复后显示当前聊天模式""" enable_friend_chat: bool = False diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py index 6797e51b..7d96ea8a 100644 --- a/src/plugins/test_plugin/actions/__init__.py +++ b/src/plugins/test_plugin/actions/__init__.py @@ -2,5 +2,6 @@ # 导入所有动作模块以确保装饰器被执行 from . import test_action # noqa + # from . import online_action # noqa from . import mute_action # noqa diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py index 4f457d90..df112a16 100644 --- a/src/plugins/test_plugin/actions/mute_action.py +++ b/src/plugins/test_plugin/actions/mute_action.py @@ -57,14 +57,15 @@ class MuteAction(PluginAction): # 确保duration是字符串类型 if int(duration) < 60: duration = 60 - if int(duration) > 3600*24*30: - duration = 3600*24*30 + if int(duration) > 3600 * 24 * 30: + duration = 3600 * 24 * 30 duration_str = str(int(duration)) # 发送群聊禁言命令,按照新格式 await self.send_message( - type = "command", data = {"name": "GROUP_BAN", "args": {"qq_id": str(user_id), "duration": duration_str}}, - display_message = f"我 禁言了 {target} {duration_str}秒" + type="command", + data={"name": "GROUP_BAN", "args": {"qq_id": str(user_id), "duration": duration_str}}, + display_message=f"我 禁言了 {target} {duration_str}秒", ) logger.info(f"{self.log_prefix} 成功发送禁言命令,用户 {target}({user_id}),时长 {duration} 秒") diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 946e3374..2d4eccb9 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -335,7 +335,7 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 enable = true [experimental] #实验性功能 -debug_show_chat_mode = true # 是否在回复后显示当前聊天模式 +debug_show_chat_mode = false # 是否在回复后显示当前聊天模式 enable_friend_chat = false # 是否启用好友聊天 pfc_chatting = false # 暂时无效 From 859e5201fce76120cb12a303ed770443144ea36f Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 28 May 2025 20:45:34 +0800 Subject: [PATCH 07/13] Update default_expressor.py --- src/chat/focus_chat/expressors/default_expressor.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 382d20aa..eceadf67 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -277,15 +277,8 @@ class DefaultExpressor: in_mind_reply, target_message, ) -> str: - # prompt_personality = individuality.get_prompt(x_person=0, level=2) - - # Determine if it's a group chat is_group_chat = bool(chat_stream.group_info) - # Use sender_name passed from caller for private chat, otherwise use a default for group - # Default sender_name for group chat isn't used in the group prompt template, but set for consistency - # effective_sender_name = sender_name if not is_group_chat else "某人" - message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), From 184a74bb8b18d6df2775bb59cfd80b07a07e8724 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 28 May 2025 20:50:18 +0800 Subject: [PATCH 08/13] 1 --- mongodb_to_sqlite .bat | 142 ++++++++++++++++++++--------------------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite .bat index 5c029e44..f960e508 100644 --- a/mongodb_to_sqlite .bat +++ b/mongodb_to_sqlite .bat @@ -1,72 +1,72 @@ -@echo off -CHCP 65001 > nul -setlocal enabledelayedexpansion - -echo 你需要选择启动方式,输入字母来选择: -echo V = 不知道什么意思就输入 V -echo C = 输入 C 使用 Conda 环境 -echo. -choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V - -set "ENV_TYPE=" -if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" -if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" - -if "%ENV_TYPE%" == "CONDA" goto activate_conda -if "%ENV_TYPE%" == "VENV" goto activate_venv - -REM 如果 choice 超时或返回意外值,默认使用 venv -echo WARN: Invalid selection or timeout from choice. Defaulting to VENV. -set "ENV_TYPE=VENV" -goto activate_venv - -:activate_conda - set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " - if not defined CONDA_ENV_NAME ( - echo 错误: 未输入 Conda 环境名称. - pause - exit /b 1 - ) - echo 选择: Conda '!CONDA_ENV_NAME!' - REM 激活Conda环境 - call conda activate !CONDA_ENV_NAME! - if !ERRORLEVEL! neq 0 ( - echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在. - pause - exit /b 1 - ) - goto env_activated - -:activate_venv - echo Selected: venv (default or selected) - REM 查找venv虚拟环境 - set "venv_path=%~dp0venv\Scripts\activate.bat" - if not exist "%venv_path%" ( - echo Error: venv not found. Ensure the venv directory exists alongside the script. - pause - exit /b 1 - ) - REM 激活虚拟环境 - call "%venv_path%" - if %ERRORLEVEL% neq 0 ( - echo Error: Failed to activate venv virtual environment. - pause - exit /b 1 - ) - goto env_activated - -:env_activated -echo Environment activated successfully! - -REM --- 后续脚本执行 --- - -REM 运行预处理脚本 -python "%~dp0scripts\mongodb_to_sqlite.py" -if %ERRORLEVEL% neq 0 ( - echo Error: mongodb_to_sqlite.py execution failed. - pause - exit /b 1 -) - -echo All processing steps completed! +@echo off +CHCP 65001 > nul +setlocal enabledelayedexpansion + +echo 你需要选择启动方式,输入字母来选择: +echo V = 不知道什么意思就输入 V +echo C = 输入 C 使用 Conda 环境 +echo. +choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V + +set "ENV_TYPE=" +if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA" +if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV" + +if "%ENV_TYPE%" == "CONDA" goto activate_conda +if "%ENV_TYPE%" == "VENV" goto activate_venv + +REM 如果 choice 超时或返回意外值,默认使用 venv +echo WARN: Invalid selection or timeout from choice. Defaulting to VENV. +set "ENV_TYPE=VENV" +goto activate_venv + +:activate_conda + set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: " + if not defined CONDA_ENV_NAME ( + echo 错误: 未输入 Conda 环境名称. + pause + exit /b 1 + ) + echo 选择: Conda '!CONDA_ENV_NAME!' + REM 激活Conda环境 + call conda activate !CONDA_ENV_NAME! + if !ERRORLEVEL! neq 0 ( + echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在. + pause + exit /b 1 + ) + goto env_activated + +:activate_venv + echo Selected: venv (default or selected) + REM 查找venv虚拟环境 + set "venv_path=%~dp0venv\Scripts\activate.bat" + if not exist "%venv_path%" ( + echo Error: venv not found. Ensure the venv directory exists alongside the script. + pause + exit /b 1 + ) + REM 激活虚拟环境 + call "%venv_path%" + if %ERRORLEVEL% neq 0 ( + echo Error: Failed to activate venv virtual environment. + pause + exit /b 1 + ) + goto env_activated + +:env_activated +echo Environment activated successfully! + +REM --- 后续脚本执行 --- + +REM 运行预处理脚本 +python "%~dp0scripts\mongodb_to_sqlite.py" +if %ERRORLEVEL% neq 0 ( + echo Error: mongodb_to_sqlite.py execution failed. + pause + exit /b 1 +) + +echo All processing steps completed! pause \ No newline at end of file From 460c7fb75ac839737f6cf7afc443d96a4b350271 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 28 May 2025 20:51:38 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=BF=81=E7=A7=BB?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=8C=E6=94=AF=E6=8C=81=E6=96=AD=E7=82=B9?= =?UTF-8?q?=E7=BB=AD=E4=BC=A0=E5=92=8C=E6=89=B9=E9=87=8F=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=EF=BC=8C=E5=85=81=E8=AE=B8=E7=BE=A4=E8=81=8A=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E5=AD=97=E6=AE=B5=E4=B8=BA=E5=8F=AF=E9=80=89=EF=BC=88?= =?UTF-8?q?null=3DTrue=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/mongodb_to_sqlite.py | 485 ++++++++++++++++++++------ src/common/database/database_model.py | 6 +- 2 files changed, 374 insertions(+), 117 deletions(-) diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 5ff346af..609906fa 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -1,6 +1,9 @@ import os import json import sys # 新增系统模块导入 +# import time +import pickle +from pathlib import Path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from typing import Dict, Any, List, Optional, Type @@ -10,16 +13,31 @@ from pymongo import MongoClient from pymongo.errors import ConnectionFailure from peewee import Model, Field, IntegrityError +# Rich 进度条和显示组件 +from rich.console import Console +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, + TimeElapsedColumn, + SpinnerColumn +) +from rich.table import Table +from rich.panel import Panel +# from rich.text import Text + from src.common.database.database import db from src.common.database.database_model import ( ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions, - OnlineTime, PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges + PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges ) from src.common.logger_manager import get_logger logger = get_logger("mongodb_to_sqlite") - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @dataclass class MigrationConfig: """迁移配置类""" @@ -32,6 +50,19 @@ class MigrationConfig: unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 +# 数据验证相关类已移除 - 用户要求不要数据验证 + + +@dataclass +class MigrationCheckpoint: + """迁移断点数据""" + collection_name: str + processed_count: int + last_processed_id: Any + timestamp: datetime + batch_errors: List[Dict[str, Any]] = field(default_factory=list) + + @dataclass class MigrationStats: """迁移统计信息""" @@ -41,7 +72,11 @@ class MigrationStats: error_count: int = 0 skipped_count: int = 0 duplicate_count: int = 0 + validation_errors: int = 0 + batch_insert_count: int = 0 errors: List[Dict[str, Any]] = field(default_factory=list) + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): """添加错误记录""" @@ -52,6 +87,11 @@ class MigrationStats: 'doc_data': doc_data }) self.error_count += 1 + + def add_validation_error(self, doc_id: Any, field: str, error: str): + """添加验证错误""" + self.add_error(doc_id, f"验证失败 - {field}: {error}") + self.validation_errors += 1 class MongoToSQLiteMigrator: @@ -65,6 +105,15 @@ class MongoToSQLiteMigrator: # 迁移配置 self.migration_configs = self._initialize_migration_configs() + + # 进度条控制台 + self.console = Console() + # 检查点目录 + self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints")) + self.checkpoint_dir.mkdir(exist_ok=True) + + # 验证规则已禁用 + self.validation_rules = self._initialize_validation_rules() def _build_mongo_uri(self) -> str: """构建MongoDB连接URI""" @@ -84,8 +133,7 @@ class MongoToSQLiteMigrator: def _initialize_migration_configs(self) -> List[MigrationConfig]: """初始化迁移配置""" - return [ - # 表情包迁移配置 + return [ # 表情包迁移配置 MigrationConfig( mongo_collection="emoji", target_model=Emoji, @@ -96,13 +144,13 @@ class MongoToSQLiteMigrator: "description": "description", "emotion": "emotion", "usage_count": "usage_count", - "last_used_time": "last_used_time", - "last_used_time": "record_time" # 这个纯粹是为了应付整体映射格式,实际上直接用当前时间戳填了record_time + "last_used_time": "last_used_time" + # record_time字段将在转换时自动设置为当前时间 }, + enable_validation=False, # 禁用数据验证 unique_fields=["full_path", "emoji_hash"] ), - - # 聊天流迁移配置 + # 聊天流迁移配置 MigrationConfig( mongo_collection="chat_streams", target_model=ChatStreams, @@ -119,10 +167,10 @@ class MongoToSQLiteMigrator: "user_info.user_nickname": "user_nickname", "user_info.user_cardname": "user_cardname" }, + enable_validation=False, # 禁用数据验证 unique_fields=["stream_id"] ), - - # LLM使用记录迁移配置 + # LLM使用记录迁移配置 MigrationConfig( mongo_collection="llm_usage", target_model=LLMUsage, @@ -138,10 +186,10 @@ class MongoToSQLiteMigrator: "status": "status", "timestamp": "timestamp" }, + enable_validation=False, # 禁用数据验证 unique_fields=["user_id", "timestamp"] # 组合唯一性 ), - - # 消息迁移配置 + # 消息迁移配置 MigrationConfig( mongo_collection="messages", target_model=Messages, @@ -168,6 +216,7 @@ class MongoToSQLiteMigrator: "detailed_plain_text": "detailed_plain_text", "memorized_times": "memorized_times" }, + enable_validation=False, # 禁用数据验证 unique_fields=["message_id"] ), @@ -194,21 +243,7 @@ class MongoToSQLiteMigrator: "hash": "image_description_hash", "description": "description", "timestamp": "timestamp" - }, - unique_fields=["image_description_hash", "type"] - ), - - # 在线时长迁移配置 - MigrationConfig( - mongo_collection="online_time", - target_model=OnlineTime, - field_mapping={ - "timestamp": "timestamp", - "duration": "duration", - "start_timestamp": "start_timestamp", - "end_timestamp": "end_timestamp" - }, - unique_fields=["start_timestamp", "end_timestamp"] + }, unique_fields=["image_description_hash", "type"] ), # 个人信息迁移配置 @@ -290,6 +325,9 @@ class MongoToSQLiteMigrator: unique_fields=["source", "target"] # 组合唯一性 ) ] + def _initialize_validation_rules(self) -> Dict[str, Any]: + """数据验证已禁用 - 返回空字典""" + return {} def connect_mongodb(self) -> bool: """连接到MongoDB""" @@ -412,11 +450,69 @@ class MongoToSQLiteMigrator: elif field_type in ["FloatField", "DoubleField"]: return 0.0 elif field_type == "BooleanField": - return False + return False elif field_type == "DateTimeField": return datetime.now() return None + def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: + """数据验证已禁用 - 始终返回True""" + return True + + def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any): + """保存迁移断点""" + checkpoint = MigrationCheckpoint( + collection_name=collection_name, + processed_count=processed_count, + last_processed_id=last_id, + timestamp=datetime.now() + ) + + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + try: + with open(checkpoint_file, 'wb') as f: + pickle.dump(checkpoint, f) + except Exception as e: + logger.warning(f"保存断点失败: {e}") + + def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: + """加载迁移断点""" + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + if not checkpoint_file.exists(): + return None + + try: + with open(checkpoint_file, 'rb') as f: + return pickle.load(f) + except Exception as e: + logger.warning(f"加载断点失败: {e}") + return None + + def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: + """批量插入数据""" + if not data_list: + return 0 + + success_count = 0 + try: + with db.atomic(): + # 分批插入,避免SQL语句过长 + batch_size = 100 + for i in range(0, len(data_list), batch_size): + batch = data_list[i:i + batch_size] + model.insert_many(batch).execute() + success_count += len(batch) + except Exception as e: + logger.error(f"批量插入失败: {e}") + # 如果批量插入失败,尝试逐个插入 + for data in data_list: + try: + model.create(**data) + success_count += 1 + except Exception: + pass # 忽略单个插入失败 + + return success_count def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]) -> bool: @@ -458,17 +554,30 @@ class MongoToSQLiteMigrator: except Exception as e: logger.error(f"创建模型实例失败: {e}") return None - def migrate_collection(self, config: MigrationConfig) -> MigrationStats: - """迁移单个集合 - 使用ORM方式""" + """迁移单个集合 - 使用优化的批量插入和进度条""" stats = MigrationStats() + stats.start_time = datetime.now() + + # 检查是否有断点 + checkpoint = self._load_checkpoint(config.mongo_collection) + start_from_id = checkpoint.last_processed_id if checkpoint else None + if checkpoint: + stats.processed_count = checkpoint.processed_count + logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录") logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") try: # 获取MongoDB集合 mongo_collection = self.mongo_db[config.mongo_collection] - stats.total_documents = mongo_collection.count_documents({}) + + # 构建查询条件(用于断点恢复) + query = {} + if start_from_id: + query = {"_id": {"$gt": start_from_id}} + + stats.total_documents = mongo_collection.count_documents(query) if stats.total_documents == 0: logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") @@ -476,69 +585,112 @@ class MongoToSQLiteMigrator: logger.info(f"待迁移文档数量: {stats.total_documents}") - # 逐个处理文档 - batch_count = 0 - for mongo_doc in mongo_collection.find().batch_size(config.batch_size): - try: - stats.processed_count += 1 - doc_id = mongo_doc.get('_id', 'unknown') - - # 构建目标数据 - target_data = {} - for mongo_field, sqlite_field in config.field_mapping.items(): - value = self._get_nested_value(mongo_doc, mongo_field) - - # 获取目标字段对象并转换类型 - if hasattr(config.target_model, sqlite_field): - field_obj = getattr(config.target_model, sqlite_field) - converted_value = self._convert_field_value(value, field_obj) - target_data[sqlite_field] = converted_value - - # 重复检查 - if config.skip_duplicates and self._check_duplicate_by_unique_fields( - config.target_model, target_data, config.unique_fields - ): - stats.duplicate_count += 1 - stats.skipped_count += 1 - logger.debug(f"跳过重复记录: {doc_id}") - continue - - # 使用ORM创建实例 - with db.atomic(): # 每个实例的事务保护 - instance = self._create_model_instance(config.target_model, target_data) - - if instance: - stats.success_count += 1 - else: - stats.add_error(doc_id, "ORM创建实例失败", target_data) + # 创建Rich进度条 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=self.console, + refresh_per_second=10 + ) as progress: - except Exception as e: - doc_id = mongo_doc.get('_id', 'unknown') - stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) - logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + task = progress.add_task( + f"迁移 {config.mongo_collection}", + total=stats.total_documents + ) + # 批量处理数据 + batch_data = [] + batch_count = 0 + last_processed_id = None - # 进度报告 - batch_count += 1 - if batch_count % config.batch_size == 0: - progress = (stats.processed_count / stats.total_documents) * 100 - logger.info( - f"迁移进度: {stats.processed_count}/{stats.total_documents} " - f"({progress:.1f}%) - 成功: {stats.success_count}, " - f"错误: {stats.error_count}, 跳过: {stats.skipped_count}" - ) + for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size): + try: + doc_id = mongo_doc.get('_id', 'unknown') + last_processed_id = doc_id + + # 构建目标数据 + target_data = {} + for mongo_field, sqlite_field in config.field_mapping.items(): + value = self._get_nested_value(mongo_doc, mongo_field) + + # 获取目标字段对象并转换类型 + if hasattr(config.target_model, sqlite_field): + field_obj = getattr(config.target_model, sqlite_field) + converted_value = self._convert_field_value(value, field_obj) + target_data[sqlite_field] = converted_value + + # 数据验证已禁用 + # if config.enable_validation: + # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats): + # stats.skipped_count += 1 + # continue + + # 重复检查 + if config.skip_duplicates and self._check_duplicate_by_unique_fields( + config.target_model, target_data, config.unique_fields + ): + stats.duplicate_count += 1 + stats.skipped_count += 1 + logger.debug(f"跳过重复记录: {doc_id}") + continue + + # 添加到批量数据 + batch_data.append(target_data) + stats.processed_count += 1 + + # 执行批量插入 + if len(batch_data) >= config.batch_size: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + + # 保存断点 + self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id) + + batch_data.clear() + batch_count += 1 + + # 更新进度条 + progress.update(task, advance=config.batch_size) + + except Exception as e: + doc_id = mongo_doc.get('_id', 'unknown') + stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) + logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + + # 处理剩余的批量数据 + if batch_data: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + progress.update(task, advance=len(batch_data)) + + # 完成进度条 + progress.update(task, completed=stats.total_documents) + + stats.end_time = datetime.now() + duration = stats.end_time - stats.start_time logger.info( f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" f"总计: {stats.total_documents}, 成功: {stats.success_count}, " - f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}" + f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n" + f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}" ) + # 清理断点文件 + checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl" + if checkpoint_file.exists(): + checkpoint_file.unlink() + except Exception as e: logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") stats.add_error("collection_error", str(e)) return stats - def migrate_all(self) -> Dict[str, MigrationStats]: """执行所有迁移任务""" logger.info("开始执行数据库迁移...") @@ -550,18 +702,44 @@ class MongoToSQLiteMigrator: all_stats = {} try: - for config in self.migration_configs: - logger.info(f"\n开始处理集合: {config.mongo_collection}") + # 创建总体进度表格 + total_collections = len(self.migration_configs) + self.console.print(Panel( + f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" + f"[yellow]总集合数: {total_collections}[/yellow]", + title="迁移开始", + expand=False + )) + for idx, config in enumerate(self.migration_configs, 1): + self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]") stats = self.migrate_collection(config) all_stats[config.mongo_collection] = stats + # 显示单个集合的快速统计 + if stats.processed_count > 0: + success_rate = (stats.success_count / stats.processed_count * 100) + if success_rate >= 95: + status_emoji = "✅" + status_color = "bright_green" + elif success_rate >= 80: + status_emoji = "⚠️" + status_color = "yellow" + else: + status_emoji = "❌" + status_color = "red" + + self.console.print( + f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} " + f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]" + ) + # 错误率检查 if stats.processed_count > 0: error_rate = stats.error_count / stats.processed_count if error_rate > 0.1: # 错误率超过10% - logger.warning( - f"集合 {config.mongo_collection} 错误率较高: {error_rate:.1%} " - f"({stats.error_count}/{stats.processed_count})" + self.console.print( + f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} " + f"({stats.error_count}/{stats.processed_count})[/red]" ) finally: @@ -571,52 +749,131 @@ class MongoToSQLiteMigrator: return all_stats def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): - """打印迁移汇总信息""" - logger.info("\n" + "="*60) - logger.info("数据迁移汇总报告") - logger.info("="*60) - + """使用Rich打印美观的迁移汇总信息""" + # 计算总体统计 total_processed = sum(stats.processed_count for stats in all_stats.values()) total_success = sum(stats.success_count for stats in all_stats.values()) total_errors = sum(stats.error_count for stats in all_stats.values()) total_skipped = sum(stats.skipped_count for stats in all_stats.values()) total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) + total_validation_errors = sum(stats.validation_errors for stats in all_stats.values()) + total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values()) - # 表头 - logger.info(f"{'集合名称':<20} | {'处理':<6} | {'成功':<6} | {'错误':<6} | {'跳过':<6} | {'重复':<6} | {'成功率':<8}") - logger.info("-" * 75) + # 计算总耗时 + total_duration_seconds = 0 + for stats in all_stats.values(): + if stats.start_time and stats.end_time: + duration = stats.end_time - stats.start_time + total_duration_seconds += duration.total_seconds() + + # 创建详细统计表格 + table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta") + table.add_column("集合名称", style="cyan", width=20) + table.add_column("文档总数", justify="right", style="blue") + table.add_column("处理数量", justify="right", style="green") + table.add_column("成功数量", justify="right", style="green") + table.add_column("错误数量", justify="right", style="red") + table.add_column("跳过数量", justify="right", style="yellow") + table.add_column("重复数量", justify="right", style="bright_yellow") + table.add_column("验证错误", justify="right", style="red") + table.add_column("批次数", justify="right", style="purple") + table.add_column("成功率", justify="right", style="bright_green") + table.add_column("耗时(秒)", justify="right", style="blue") for collection_name, stats in all_stats.items(): success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 - logger.info( - f"{collection_name:<20} | " - f"{stats.processed_count:<6} | " - f"{stats.success_count:<6} | " - f"{stats.error_count:<6} | " - f"{stats.skipped_count:<6} | " - f"{stats.duplicate_count:<6} | " - f"{success_rate:<7.1f}%" + duration = 0 + if stats.start_time and stats.end_time: + duration = (stats.end_time - stats.start_time).total_seconds() + + # 根据成功率设置颜色 + if success_rate >= 95: + success_rate_style = "[bright_green]" + elif success_rate >= 80: + success_rate_style = "[yellow]" + else: + success_rate_style = "[red]" + + table.add_row( + collection_name, + str(stats.total_documents), + str(stats.processed_count), + str(stats.success_count), + f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0", + f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0", + f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0", + f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0", + str(stats.batch_insert_count), + f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}", + f"{duration:.2f}" ) - logger.info("-" * 75) + # 添加总计行 total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 - logger.info( - f"{'总计':<20} | " - f"{total_processed:<6} | " - f"{total_success:<6} | " - f"{total_errors:<6} | " - f"{total_skipped:<6} | " - f"{total_duplicates:<6} | " - f"{total_success_rate:<7.1f}%" + if total_success_rate >= 95: + total_rate_style = "[bright_green]" + elif total_success_rate >= 80: + total_rate_style = "[yellow]" + else: + total_rate_style = "[red]" + + table.add_section() + table.add_row( + "[bold]总计[/bold]", + f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]", + f"[bold]{total_processed}[/bold]", + f"[bold]{total_success}[/bold]", + f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]", + f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]", + f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]", + f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]", + f"[bold]{total_batch_inserts}[/bold]", + f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]", + f"[bold]{total_duration_seconds:.2f}[/bold]" ) + self.console.print(table) + + # 创建状态面板 + status_items = [] if total_errors > 0: - logger.warning(f"\n⚠️ 存在 {total_errors} 个错误,请检查日志详情") + status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]") + + if total_validation_errors > 0: + status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]") if total_duplicates > 0: - logger.info(f"ℹ️ 跳过了 {total_duplicates} 个重复记录") + status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]") - logger.info("="*60) + if total_success_rate >= 95: + status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]") + elif total_success_rate >= 80: + status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]") + else: + status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]") + + if status_items: + status_panel = Panel( + "\n".join(status_items), + title="[bold yellow]迁移状态总结[/bold yellow]", + border_style="yellow" + ) + self.console.print(status_panel) + + # 性能统计面板 + avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0 + performance_info = ( + f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n" + f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n" + f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作" + ) + + performance_panel = Panel( + performance_info, + title="[bold green]性能统计[/bold green]", + border_style="green" + ) + self.console.print(performance_panel) def add_migration_config(self, config: MigrationConfig): """添加新的迁移配置""" diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index ccf78964..bd264637 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -44,9 +44,9 @@ class ChatStreams(BaseModel): # platform: "qq" # group_id: "941657197" # group_name: "测试" - group_platform = TextField() - group_id = TextField() - group_name = TextField() + group_platform = TextField(null=True) # 群聊信息可能不存在 + group_id = TextField(null=True) + group_name = TextField(null=True) # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位) last_active_time = DoubleField() From 847bd23a62e602b98b708dce293da62ad9873f24 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 28 May 2025 12:51:51 +0000 Subject: [PATCH 10/13] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/mongodb_to_sqlite.py | 443 +++++++++++++++++------------------ 1 file changed, 220 insertions(+), 223 deletions(-) diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 609906fa..0f292fef 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -1,6 +1,7 @@ import os import json import sys # 新增系统模块导入 + # import time import pickle from pathlib import Path @@ -16,13 +17,13 @@ from peewee import Model, Field, IntegrityError # Rich 进度条和显示组件 from rich.console import Console from rich.progress import ( - Progress, - TextColumn, - BarColumn, - TaskProgressColumn, + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn, - SpinnerColumn + SpinnerColumn, ) from rich.table import Table from rich.panel import Panel @@ -30,17 +31,29 @@ from rich.panel import Panel from src.common.database.database import db from src.common.database.database_model import ( - ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions, - PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + PersonInfo, + Knowledges, + ThinkingLog, + GraphNodes, + GraphEdges, ) from src.common.logger_manager import get_logger logger = get_logger("mongodb_to_sqlite") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + + @dataclass class MigrationConfig: """迁移配置类""" + mongo_collection: str target_model: Type[Model] field_mapping: Dict[str, str] @@ -56,6 +69,7 @@ class MigrationConfig: @dataclass class MigrationCheckpoint: """迁移断点数据""" + collection_name: str processed_count: int last_processed_id: Any @@ -66,6 +80,7 @@ class MigrationCheckpoint: @dataclass class MigrationStats: """迁移统计信息""" + total_documents: int = 0 processed_count: int = 0 success_count: int = 0 @@ -77,17 +92,14 @@ class MigrationStats: errors: List[Dict[str, Any]] = field(default_factory=list) start_time: Optional[datetime] = None end_time: Optional[datetime] = None - + def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): """添加错误记录""" - self.errors.append({ - 'doc_id': str(doc_id), - 'error': error, - 'timestamp': datetime.now().isoformat(), - 'doc_data': doc_data - }) + self.errors.append( + {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data} + ) self.error_count += 1 - + def add_validation_error(self, doc_id: Any, field: str, error: str): """添加验证错误""" self.add_error(doc_id, f"验证失败 - {field}: {error}") @@ -96,81 +108,81 @@ class MigrationStats: class MongoToSQLiteMigrator: """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" - + def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None): self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") self.mongo_uri = mongo_uri or self._build_mongo_uri() self.mongo_client: Optional[MongoClient] = None self.mongo_db = None - + # 迁移配置 self.migration_configs = self._initialize_migration_configs() - + # 进度条控制台 self.console = Console() - # 检查点目录 + # 检查点目录 self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints")) self.checkpoint_dir.mkdir(exist_ok=True) - + # 验证规则已禁用 self.validation_rules = self._initialize_validation_rules() - + def _build_mongo_uri(self) -> str: """构建MongoDB连接URI""" if mongo_uri := os.getenv("MONGODB_URI"): return mongo_uri - - user = os.getenv('MONGODB_USER') - password = os.getenv('MONGODB_PASS') - host = os.getenv('MONGODB_HOST', 'localhost') - port = os.getenv('MONGODB_PORT', '27017') - auth_source = os.getenv('MONGODB_AUTH_SOURCE', 'admin') - + + user = os.getenv("MONGODB_USER") + password = os.getenv("MONGODB_PASS") + host = os.getenv("MONGODB_HOST", "localhost") + port = os.getenv("MONGODB_PORT", "27017") + auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin") + if user and password: return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}" else: return f"mongodb://{host}:{port}/{self.database_name}" - + def _initialize_migration_configs(self) -> List[MigrationConfig]: """初始化迁移配置""" - return [ # 表情包迁移配置 + return [ # 表情包迁移配置 MigrationConfig( mongo_collection="emoji", target_model=Emoji, field_mapping={ "full_path": "full_path", - "format": "format", + "format": "format", "hash": "emoji_hash", "description": "description", "emotion": "emotion", "usage_count": "usage_count", - "last_used_time": "last_used_time" + "last_used_time": "last_used_time", # record_time字段将在转换时自动设置为当前时间 }, enable_validation=False, # 禁用数据验证 - unique_fields=["full_path", "emoji_hash"] + unique_fields=["full_path", "emoji_hash"], ), - # 聊天流迁移配置 + # 聊天流迁移配置 MigrationConfig( mongo_collection="chat_streams", target_model=ChatStreams, field_mapping={ "stream_id": "stream_id", "create_time": "create_time", - "group_info.platform": "group_platform",# 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。 + "group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。 "group_info.group_id": "group_id", # 同上 - "group_info.group_name": "group_name", # 同上 + "group_info.group_name": "group_name", # 同上 "last_active_time": "last_active_time", "platform": "platform", "user_info.platform": "user_platform", "user_info.user_id": "user_id", "user_info.user_nickname": "user_nickname", - "user_info.user_cardname": "user_cardname" + "user_info.user_cardname": "user_cardname", }, enable_validation=False, # 禁用数据验证 - unique_fields=["stream_id"] + unique_fields=["stream_id"], ), - # LLM使用记录迁移配置 + # LLM使用记录迁移配置 MigrationConfig( mongo_collection="llm_usage", target_model=LLMUsage, @@ -184,12 +196,12 @@ class MongoToSQLiteMigrator: "total_tokens": "total_tokens", "cost": "cost", "status": "status", - "timestamp": "timestamp" + "timestamp": "timestamp", }, enable_validation=False, # 禁用数据验证 - unique_fields=["user_id", "timestamp"] # 组合唯一性 + unique_fields=["user_id", "timestamp"], # 组合唯一性 ), - # 消息迁移配置 + # 消息迁移配置 MigrationConfig( mongo_collection="messages", target_model=Messages, @@ -214,12 +226,11 @@ class MongoToSQLiteMigrator: "user_info.user_cardname": "user_cardname", "processed_plain_text": "processed_plain_text", "detailed_plain_text": "detailed_plain_text", - "memorized_times": "memorized_times" + "memorized_times": "memorized_times", }, enable_validation=False, # 禁用数据验证 - unique_fields=["message_id"] + unique_fields=["message_id"], ), - # 图片迁移配置 MigrationConfig( mongo_collection="images", @@ -229,11 +240,10 @@ class MongoToSQLiteMigrator: "description": "description", "path": "path", "timestamp": "timestamp", - "type": "type" + "type": "type", }, - unique_fields=["path"] + unique_fields=["path"], ), - # 图片描述迁移配置 MigrationConfig( mongo_collection="image_descriptions", @@ -242,10 +252,10 @@ class MongoToSQLiteMigrator: "type": "type", "hash": "image_description_hash", "description": "description", - "timestamp": "timestamp" - }, unique_fields=["image_description_hash", "type"] + "timestamp": "timestamp", + }, + unique_fields=["image_description_hash", "type"], ), - # 个人信息迁移配置 MigrationConfig( mongo_collection="person_info", @@ -260,22 +270,17 @@ class MongoToSQLiteMigrator: "relationship_value": "relationship_value", "konw_time": "know_time", "msg_interval": "msg_interval", - "msg_interval_list": "msg_interval_list" + "msg_interval_list": "msg_interval_list", }, - unique_fields=["person_id"] + unique_fields=["person_id"], ), - # 知识库迁移配置 MigrationConfig( mongo_collection="knowledges", target_model=Knowledges, - field_mapping={ - "content": "content", - "embedding": "embedding" - }, - unique_fields=["content"] # 假设内容唯一 + field_mapping={"content": "content", "embedding": "embedding"}, + unique_fields=["content"], # 假设内容唯一 ), - # 思考日志迁移配置 MigrationConfig( mongo_collection="thinking_log", @@ -293,9 +298,8 @@ class MongoToSQLiteMigrator: "heartflow_data": "heartflow_data_json", "reasoning_data": "reasoning_data_json", }, - unique_fields=["chat_id", "trigger_text"] + unique_fields=["chat_id", "trigger_text"], ), - # 图节点迁移配置 MigrationConfig( mongo_collection="graph_data.nodes", @@ -305,11 +309,10 @@ class MongoToSQLiteMigrator: "memory_items": "memory_items", "hash": "hash", "created_time": "created_time", - "last_modified": "last_modified" + "last_modified": "last_modified", }, - unique_fields=["concept"] + unique_fields=["concept"], ), - # 图边迁移配置 MigrationConfig( mongo_collection="graph_data.edges", @@ -320,71 +323,69 @@ class MongoToSQLiteMigrator: "strength": "strength", "hash": "hash", "created_time": "created_time", - "last_modified": "last_modified" + "last_modified": "last_modified", }, - unique_fields=["source", "target"] # 组合唯一性 - ) + unique_fields=["source", "target"], # 组合唯一性 + ), ] + def _initialize_validation_rules(self) -> Dict[str, Any]: """数据验证已禁用 - 返回空字典""" return {} - + def connect_mongodb(self) -> bool: """连接到MongoDB""" try: self.mongo_client = MongoClient( - self.mongo_uri, - serverSelectionTimeoutMS=5000, - connectTimeoutMS=10000, - maxPoolSize=10 + self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10 ) - + # 测试连接 - self.mongo_client.admin.command('ping') + self.mongo_client.admin.command("ping") self.mongo_db = self.mongo_client[self.database_name] - + logger.info(f"成功连接到MongoDB: {self.database_name}") return True - + except ConnectionFailure as e: logger.error(f"MongoDB连接失败: {e}") return False except Exception as e: logger.error(f"MongoDB连接异常: {e}") return False - + def disconnect_mongodb(self): """断开MongoDB连接""" if self.mongo_client: self.mongo_client.close() logger.info("MongoDB连接已关闭") - + def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any: """获取嵌套字段的值""" if "." not in field_path: return document.get(field_path) - + parts = field_path.split(".") value = document - + for part in parts: if isinstance(value, dict): value = value.get(part) else: return None - + if value is None: break - + return value - + def _convert_field_value(self, value: Any, target_field: Field) -> Any: """根据目标字段类型转换值""" if value is None: return None - + field_type = target_field.__class__.__name__ - + try: if target_field.name == "record_time" and field_type == "DateTimeField": return datetime.now() @@ -393,31 +394,31 @@ class MongoToSQLiteMigrator: if isinstance(value, (list, dict)): return json.dumps(value, ensure_ascii=False) return str(value) if value is not None else "" - + elif field_type == "IntegerField": if isinstance(value, str): # 处理字符串数字 clean_value = value.strip() - if clean_value.replace('.', '').replace('-', '').isdigit(): + if clean_value.replace(".", "").replace("-", "").isdigit(): return int(float(clean_value)) return 0 return int(value) if value is not None else 0 - + elif field_type in ["FloatField", "DoubleField"]: return float(value) if value is not None else 0.0 - + elif field_type == "BooleanField": if isinstance(value, str): - return value.lower() in ('true', '1', 'yes', 'on') + return value.lower() in ("true", "1", "yes", "on") return bool(value) - + elif field_type == "DateTimeField": if isinstance(value, (int, float)): return datetime.fromtimestamp(value) elif isinstance(value, str): try: # 尝试解析ISO格式日期 - return datetime.fromisoformat(value.replace('Z', '+00:00')) + return datetime.fromisoformat(value.replace("Z", "+00:00")) except ValueError: try: # 尝试解析时间戳字符串 @@ -425,23 +426,23 @@ class MongoToSQLiteMigrator: except ValueError: return datetime.now() return datetime.now() - + return value - + except (ValueError, TypeError) as e: logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}") return self._get_default_value_for_field(target_field) - + def _get_default_value_for_field(self, field: Field) -> Any: """获取字段的默认值""" field_type = field.__class__.__name__ - - if hasattr(field, 'default') and field.default is not None: + + if hasattr(field, "default") and field.default is not None: return field.default - + if field.null: return None - + # 根据字段类型返回默认值 if field_type in ["CharField", "TextField"]: return "" @@ -450,56 +451,57 @@ class MongoToSQLiteMigrator: elif field_type in ["FloatField", "DoubleField"]: return 0.0 elif field_type == "BooleanField": - return False + return False elif field_type == "DateTimeField": return datetime.now() - + return None + def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: """数据验证已禁用 - 始终返回True""" return True - + def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any): """保存迁移断点""" checkpoint = MigrationCheckpoint( collection_name=collection_name, processed_count=processed_count, last_processed_id=last_id, - timestamp=datetime.now() + timestamp=datetime.now(), ) - + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" try: - with open(checkpoint_file, 'wb') as f: + with open(checkpoint_file, "wb") as f: pickle.dump(checkpoint, f) except Exception as e: logger.warning(f"保存断点失败: {e}") - + def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: """加载迁移断点""" checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" if not checkpoint_file.exists(): return None - + try: - with open(checkpoint_file, 'rb') as f: + with open(checkpoint_file, "rb") as f: return pickle.load(f) except Exception as e: logger.warning(f"加载断点失败: {e}") return None - + def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: """批量插入数据""" if not data_list: return 0 - + success_count = 0 try: with db.atomic(): # 分批插入,避免SQL语句过长 batch_size = 100 for i in range(0, len(data_list), batch_size): - batch = data_list[i:i + batch_size] + batch = data_list[i : i + batch_size] model.insert_many(batch).execute() success_count += len(batch) except Exception as e: @@ -511,27 +513,28 @@ class MongoToSQLiteMigrator: success_count += 1 except Exception: pass # 忽略单个插入失败 - + return success_count - - def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any], - unique_fields: List[str]) -> bool: + + def _check_duplicate_by_unique_fields( + self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str] + ) -> bool: """根据唯一字段检查重复""" if not unique_fields: return False - + try: query = model.select() for field_name in unique_fields: if field_name in data and data[field_name] is not None: field_obj = getattr(model, field_name) query = query.where(field_obj == data[field_name]) - + return query.exists() except Exception as e: logger.debug(f"重复检查失败: {e}") return False - + def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]: """使用ORM创建模型实例""" try: @@ -542,11 +545,11 @@ class MongoToSQLiteMigrator: valid_data[field_name] = value else: logger.debug(f"跳过未知字段: {field_name}") - + # 创建实例 instance = model.create(**valid_data) return instance - + except IntegrityError as e: # 处理唯一约束冲突等完整性错误 logger.debug(f"完整性约束冲突: {e}") @@ -554,37 +557,38 @@ class MongoToSQLiteMigrator: except Exception as e: logger.error(f"创建模型实例失败: {e}") return None + def migrate_collection(self, config: MigrationConfig) -> MigrationStats: """迁移单个集合 - 使用优化的批量插入和进度条""" stats = MigrationStats() stats.start_time = datetime.now() - + # 检查是否有断点 checkpoint = self._load_checkpoint(config.mongo_collection) start_from_id = checkpoint.last_processed_id if checkpoint else None if checkpoint: stats.processed_count = checkpoint.processed_count logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录") - + logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") - + try: # 获取MongoDB集合 mongo_collection = self.mongo_db[config.mongo_collection] - + # 构建查询条件(用于断点恢复) query = {} if start_from_id: query = {"_id": {"$gt": start_from_id}} - + stats.total_documents = mongo_collection.count_documents(query) - + if stats.total_documents == 0: logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") return stats - + logger.info(f"待迁移文档数量: {stats.total_documents}") - + # 创建Rich进度条 with Progress( SpinnerColumn(), @@ -594,40 +598,36 @@ class MongoToSQLiteMigrator: TimeElapsedColumn(), TimeRemainingColumn(), console=self.console, - refresh_per_second=10 + refresh_per_second=10, ) as progress: - - task = progress.add_task( - f"迁移 {config.mongo_collection}", - total=stats.total_documents - ) - # 批量处理数据 + task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents) + # 批量处理数据 batch_data = [] batch_count = 0 last_processed_id = None - + for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size): try: - doc_id = mongo_doc.get('_id', 'unknown') + doc_id = mongo_doc.get("_id", "unknown") last_processed_id = doc_id - + # 构建目标数据 target_data = {} for mongo_field, sqlite_field in config.field_mapping.items(): value = self._get_nested_value(mongo_doc, mongo_field) - + # 获取目标字段对象并转换类型 if hasattr(config.target_model, sqlite_field): field_obj = getattr(config.target_model, sqlite_field) converted_value = self._convert_field_value(value, field_obj) target_data[sqlite_field] = converted_value - + # 数据验证已禁用 # if config.enable_validation: # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats): # stats.skipped_count += 1 # continue - + # 重复检查 if config.skip_duplicates and self._check_duplicate_by_unique_fields( config.target_model, target_data, config.unique_fields @@ -636,88 +636,93 @@ class MongoToSQLiteMigrator: stats.skipped_count += 1 logger.debug(f"跳过重复记录: {doc_id}") continue - + # 添加到批量数据 batch_data.append(target_data) stats.processed_count += 1 - + # 执行批量插入 if len(batch_data) >= config.batch_size: success_count = self._batch_insert(config.target_model, batch_data) stats.success_count += success_count stats.batch_insert_count += 1 - + # 保存断点 self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id) - + batch_data.clear() batch_count += 1 - + # 更新进度条 progress.update(task, advance=config.batch_size) - + except Exception as e: - doc_id = mongo_doc.get('_id', 'unknown') + doc_id = mongo_doc.get("_id", "unknown") stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) logger.error(f"处理文档失败 (ID: {doc_id}): {e}") - + # 处理剩余的批量数据 if batch_data: success_count = self._batch_insert(config.target_model, batch_data) stats.success_count += success_count stats.batch_insert_count += 1 progress.update(task, advance=len(batch_data)) - + # 完成进度条 progress.update(task, completed=stats.total_documents) - + stats.end_time = datetime.now() duration = stats.end_time - stats.start_time - + logger.info( f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" f"总计: {stats.total_documents}, 成功: {stats.success_count}, " f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n" f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}" ) - + # 清理断点文件 checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl" if checkpoint_file.exists(): checkpoint_file.unlink() - + except Exception as e: logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") stats.add_error("collection_error", str(e)) - + return stats + def migrate_all(self) -> Dict[str, MigrationStats]: """执行所有迁移任务""" logger.info("开始执行数据库迁移...") - + if not self.connect_mongodb(): logger.error("无法连接到MongoDB,迁移终止") return {} - + all_stats = {} - + try: # 创建总体进度表格 total_collections = len(self.migration_configs) - self.console.print(Panel( - f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" - f"[yellow]总集合数: {total_collections}[/yellow]", - title="迁移开始", - expand=False - )) + self.console.print( + Panel( + f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" + f"[yellow]总集合数: {total_collections}[/yellow]", + title="迁移开始", + expand=False, + ) + ) for idx, config in enumerate(self.migration_configs, 1): - self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]") + self.console.print( + f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]" + ) stats = self.migrate_collection(config) all_stats[config.mongo_collection] = stats - + # 显示单个集合的快速统计 if stats.processed_count > 0: - success_rate = (stats.success_count / stats.processed_count * 100) + success_rate = stats.success_count / stats.processed_count * 100 if success_rate >= 95: status_emoji = "✅" status_color = "bright_green" @@ -727,12 +732,12 @@ class MongoToSQLiteMigrator: else: status_emoji = "❌" status_color = "red" - + self.console.print( f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} " f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]" ) - + # 错误率检查 if stats.processed_count > 0: error_rate = stats.error_count / stats.processed_count @@ -741,13 +746,13 @@ class MongoToSQLiteMigrator: f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} " f"({stats.error_count}/{stats.processed_count})[/red]" ) - + finally: self.disconnect_mongodb() - + self._print_migration_summary(all_stats) return all_stats - + def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): """使用Rich打印美观的迁移汇总信息""" # 计算总体统计 @@ -758,14 +763,14 @@ class MongoToSQLiteMigrator: total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) total_validation_errors = sum(stats.validation_errors for stats in all_stats.values()) total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values()) - + # 计算总耗时 total_duration_seconds = 0 for stats in all_stats.values(): if stats.start_time and stats.end_time: duration = stats.end_time - stats.start_time total_duration_seconds += duration.total_seconds() - + # 创建详细统计表格 table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta") table.add_column("集合名称", style="cyan", width=20) @@ -779,13 +784,13 @@ class MongoToSQLiteMigrator: table.add_column("批次数", justify="right", style="purple") table.add_column("成功率", justify="right", style="bright_green") table.add_column("耗时(秒)", justify="right", style="blue") - + for collection_name, stats in all_stats.items(): success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 duration = 0 if stats.start_time and stats.end_time: duration = (stats.end_time - stats.start_time).total_seconds() - + # 根据成功率设置颜色 if success_rate >= 95: success_rate_style = "[bright_green]" @@ -793,7 +798,7 @@ class MongoToSQLiteMigrator: success_rate_style = "[yellow]" else: success_rate_style = "[red]" - + table.add_row( collection_name, str(stats.total_documents), @@ -805,9 +810,9 @@ class MongoToSQLiteMigrator: f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0", str(stats.batch_insert_count), f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}", - f"{duration:.2f}" + f"{duration:.2f}", ) - + # 添加总计行 total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 if total_success_rate >= 95: @@ -816,7 +821,7 @@ class MongoToSQLiteMigrator: total_rate_style = "[yellow]" else: total_rate_style = "[red]" - + table.add_section() table.add_row( "[bold]总计[/bold]", @@ -825,41 +830,41 @@ class MongoToSQLiteMigrator: f"[bold]{total_success}[/bold]", f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]", f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]", - f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]", + f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" + if total_duplicates > 0 + else "[bold]0[/bold]", f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]", f"[bold]{total_batch_inserts}[/bold]", f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]", - f"[bold]{total_duration_seconds:.2f}[/bold]" + f"[bold]{total_duration_seconds:.2f}[/bold]", ) - + self.console.print(table) - + # 创建状态面板 status_items = [] if total_errors > 0: status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]") - + if total_validation_errors > 0: status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]") - + if total_duplicates > 0: status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]") - + if total_success_rate >= 95: status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]") elif total_success_rate >= 80: status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]") else: status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]") - + if status_items: status_panel = Panel( - "\n".join(status_items), - title="[bold yellow]迁移状态总结[/bold yellow]", - border_style="yellow" + "\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow" ) self.console.print(status_panel) - + # 性能统计面板 avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0 performance_info = ( @@ -867,60 +872,52 @@ class MongoToSQLiteMigrator: f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n" f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作" ) - - performance_panel = Panel( - performance_info, - title="[bold green]性能统计[/bold green]", - border_style="green" - ) + + performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green") self.console.print(performance_panel) - + def add_migration_config(self, config: MigrationConfig): """添加新的迁移配置""" self.migration_configs.append(config) - + def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]: """迁移单个指定的集合""" config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) if not config: logger.error(f"未找到集合 {collection_name} 的迁移配置") return None - + if not self.connect_mongodb(): logger.error("无法连接到MongoDB") return None - + try: stats = self.migrate_collection(config) self._print_migration_summary({collection_name: stats}) return stats finally: self.disconnect_mongodb() - + def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str): """导出错误报告""" error_report = { - 'timestamp': datetime.now().isoformat(), - 'summary': { + "timestamp": datetime.now().isoformat(), + "summary": { collection: { - 'total': stats.total_documents, - 'processed': stats.processed_count, - 'success': stats.success_count, - 'errors': stats.error_count, - 'skipped': stats.skipped_count, - 'duplicates': stats.duplicate_count + "total": stats.total_documents, + "processed": stats.processed_count, + "success": stats.success_count, + "errors": stats.error_count, + "skipped": stats.skipped_count, + "duplicates": stats.duplicate_count, } for collection, stats in all_stats.items() }, - 'errors': { - collection: stats.errors - for collection, stats in all_stats.items() - if stats.errors - } + "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors}, } - + try: - with open(filepath, 'w', encoding='utf-8') as f: + with open(filepath, "w", encoding="utf-8") as f: json.dump(error_report, f, ensure_ascii=False, indent=2) logger.info(f"错误报告已导出到: {filepath}") except Exception as e: @@ -930,17 +927,17 @@ class MongoToSQLiteMigrator: def main(): """主程序入口""" migrator = MongoToSQLiteMigrator() - + # 执行迁移 migration_results = migrator.migrate_all() - + # 导出错误报告(如果有错误) if any(stats.error_count > 0 for stats in migration_results.values()): error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" migrator.export_error_report(migration_results, error_report_path) - + logger.info("数据迁移完成!") if __name__ == "__main__": - main() \ No newline at end of file + main() From f528d83826f6000f09a78a31af8afed8f668918c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 28 May 2025 20:52:15 +0800 Subject: [PATCH 11/13] =?UTF-8?q?=E9=87=8D=E5=86=99=20mongodb=5Fto=5Fsqlit?= =?UTF-8?q?e.bat=20=E8=84=9A=E6=9C=AC=EF=BC=8C=E5=A2=9E=E5=8A=A0=E7=8E=AF?= =?UTF-8?q?=E5=A2=83=E9=80=89=E6=8B=A9=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20Conda=20=E5=92=8C=20venv=20=E6=BF=80=E6=B4=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mongodb_to_sqlite .bat => mongodb_to_sqlite.bat | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mongodb_to_sqlite .bat => mongodb_to_sqlite.bat (100%) diff --git a/mongodb_to_sqlite .bat b/mongodb_to_sqlite.bat similarity index 100% rename from mongodb_to_sqlite .bat rename to mongodb_to_sqlite.bat From 9bb2fe2d52777e54658fa7bc06409f48663ebf9c Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 28 May 2025 21:10:09 +0800 Subject: [PATCH 12/13] =?UTF-8?q?feat=EF=BC=9A=E4=B8=BAnoraml=5Fcaht?= =?UTF-8?q?=E5=8A=A0=E5=85=A5=E8=A1=A8=E8=BE=BE=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelogs/changelog.md | 2 +- src/chat/normal_chat/normal_chat.py | 3 +- src/chat/normal_chat/normal_prompt.py | 99 +++++++++++++++++++++++---- template/bot_config_template.toml | 13 ++-- 4 files changed, 95 insertions(+), 22 deletions(-) diff --git a/changelogs/changelog.md b/changelogs/changelog.md index c4dc8b3b..461bd7bb 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -36,7 +36,7 @@ - 优化了进入和离开normal_chat的方式 **新增表达方式学习** -- 在专注模式下,麦麦可以有独特的表达方式 +- 麦麦配置单独表达方式 - 自主学习群聊中的表达方式,更贴近群友 - 可自定义的学习频率和开关 - 根据人设生成额外的表达方式 diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 98478183..dc4da2ea 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -337,7 +337,8 @@ class NormalChat: self.recent_replies = self.recent_replies[-self.max_replies_history :] # 检查是否需要切换到focus模式 - await self._check_switch_to_focus() + if global_config.chat.chat_mode == "auto": + await self._check_switch_to_focus() info_catcher.done_catch() diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py index 54624a02..8906106a 100644 --- a/src/chat/normal_chat/normal_prompt.py +++ b/src/chat/normal_chat/normal_prompt.py @@ -10,6 +10,7 @@ from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager +from src.chat.focus_chat.expressors.exprssion_learner import expression_learner import random @@ -24,6 +25,11 @@ def init_prompt(): Prompt( """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} +请你根据情景使用以下句法,不要盲目使用,不要生硬使用,而是结合到表达中: +{grammar_habbits} + {memory_prompt} {relation_prompt} {prompt_info} @@ -31,7 +37,7 @@ def init_prompt(): {chat_talking_prompt} 现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言或者回复这条消息。\n 你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。 -你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1}, +你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},请你给出回复 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。 @@ -49,6 +55,11 @@ def init_prompt(): Prompt( """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} +请你根据情景使用以下句法,不要盲目使用,不要生硬使用,而是结合到表达中: +{grammar_habbits} + {memory_prompt} {relation_prompt} {prompt_info} @@ -58,7 +69,7 @@ def init_prompt(): 现在 {sender_name} 说的: {message_txt} 引起了你的注意,你想要回复这条消息。 你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。 -你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,{mood_prompt},{reply_style1}, +你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,{mood_prompt},请你给出回复 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号等),只输出回复内容。 @@ -103,15 +114,42 @@ class PromptBuilder: relation_prompt += await relationship_manager.build_relationship_info(person) mood_prompt = mood_manager.get_mood_prompt() - reply_styles1 = [ - ("然后给出日常且口语化的回复,平淡一些", 0.4), - ("给出非常简短的回复", 0.4), - ("给出缺失主语的回复", 0.15), - ("给出带有语病的回复", 0.05), - ] - reply_style1_chosen = random.choices( - [style[0] for style in reply_styles1], weights=[style[1] for style in reply_styles1], k=1 - )[0] + + + ( + learnt_style_expressions, + learnt_grammar_expressions, + personality_expressions, + ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) + + style_habbits = [] + grammar_habbits = [] + # 1. learnt_expressions加权随机选2条 + if learnt_style_expressions: + weights = [expr["count"] for expr in learnt_style_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 2) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 2. learnt_grammar_expressions加权随机选2条 + if learnt_grammar_expressions: + weights = [expr["count"] for expr in learnt_grammar_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 2) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 3. personality_expressions随机选1条 + if personality_expressions: + expr = random.choice(personality_expressions) + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + + style_habbits_str = "\n".join(style_habbits) + grammar_habbits_str = "\n".join(grammar_habbits) + + + + reply_styles2 = [ ("不要回复的太有条理,可以有个性", 0.6), ("不要回复的太有条理,可以复读", 0.15), @@ -208,7 +246,8 @@ class PromptBuilder: bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, - reply_style1=reply_style1_chosen, + style_habbits=style_habbits_str, + grammar_habbits=grammar_habbits_str, reply_style2=reply_style2_chosen, keywords_reaction_prompt=keywords_reaction_prompt, prompt_ger=prompt_ger, @@ -231,7 +270,8 @@ class PromptBuilder: bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, - reply_style1=reply_style1_chosen, + style_habbits=style_habbits_str, + grammar_habbits=grammar_habbits_str, reply_style2=reply_style2_chosen, keywords_reaction_prompt=keywords_reaction_prompt, prompt_ger=prompt_ger, @@ -266,6 +306,39 @@ class PromptBuilder: except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "未检索到知识" + +def weighted_sample_no_replacement(items, weights, k) -> list: + """ + 加权且不放回地随机抽取k个元素。 + + 参数: + items: 待抽取的元素列表 + weights: 每个元素对应的权重(与items等长,且为正数) + k: 需要抽取的元素个数 + 返回: + selected: 按权重加权且不重复抽取的k个元素组成的列表 + + 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 + + 实现思路: + 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 + 这样保证了: + 1. count越大被选中概率越高 + 2. 不会重复选中同一个元素 + """ + selected = [] + pool = list(zip(items, weights)) + for _ in range(min(k, len(pool))): + total = sum(w for _, w in pool) + r = random.uniform(0, total) + upto = 0 + for idx, (item, weight) in enumerate(pool): + upto += weight + if upto >= r: + selected.append(item) + pool.pop(idx) + break + return selected init_prompt() diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 2d4eccb9..435abddf 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -38,6 +38,12 @@ identity_detail = [ # 可以描述外貌,性别,身高,职业,属性等等描述 # 条数任意,不能为0 +[expression] +# 表达方式 +expression_style = "描述麦麦说话的表达风格,表达习惯" +enable_expression_learning = true # 是否启用表达学习,麦麦会学习人类说话风格 +learning_interval = 600 # 学习间隔 单位秒 + [relationship] give_name = true # 麦麦是否给其他人取名,关闭后无法使用禁言功能 @@ -97,13 +103,6 @@ self_identify_processor = true # 是否启用自我识别处理器 tool_use_processor = false # 是否启用工具使用处理器 working_memory_processor = false # 是否启用工作记忆处理器 -[expression] -# 表达方式 -expression_style = "描述麦麦说话的表达风格,表达习惯" -enable_expression_learning = true # 是否启用表达学习 -learning_interval = 600 # 学习间隔 单位秒 - - [emoji] max_reg_num = 40 # 表情包最大注册数量 do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包 From f57903ff7c95b54a80be5a46c2b52a3fb16ce61b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 28 May 2025 13:10:24 +0000 Subject: [PATCH 13/13] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/normal_chat/normal_prompt.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py index 8906106a..9618987a 100644 --- a/src/chat/normal_chat/normal_prompt.py +++ b/src/chat/normal_chat/normal_prompt.py @@ -114,14 +114,13 @@ class PromptBuilder: relation_prompt += await relationship_manager.build_relationship_info(person) mood_prompt = mood_manager.get_mood_prompt() - - + ( learnt_style_expressions, learnt_grammar_expressions, personality_expressions, ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) - + style_habbits = [] grammar_habbits = [] # 1. learnt_expressions加权随机选2条 @@ -146,10 +145,7 @@ class PromptBuilder: style_habbits_str = "\n".join(style_habbits) grammar_habbits_str = "\n".join(grammar_habbits) - - - - + reply_styles2 = [ ("不要回复的太有条理,可以有个性", 0.6), ("不要回复的太有条理,可以复读", 0.15), @@ -306,7 +302,8 @@ class PromptBuilder: except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "未检索到知识" - + + def weighted_sample_no_replacement(items, weights, k) -> list: """ 加权且不放回地随机抽取k个元素。