From 26339453fef00bbb37f47f98f663726d8da4e25d Mon Sep 17 00:00:00 2001 From: Night-stars-1 Date: Thu, 24 Apr 2025 22:50:50 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80=E7=AE=A1=E7=90=86?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=EF=BC=8C=E9=81=BF=E5=85=8D?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E6=9C=AA=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E5=89=8D=E4=BD=BF=E7=94=A8=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E4=BB=A5=E5=8F=8A=E9=87=8D=E5=A4=8D=E5=AF=BC=E5=85=A5.env?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + src/api/config_api.py | 8 --- src/common/__init__.py | 5 ++ src/common/common.py | 7 ++ src/common/database.py | 17 +++-- src/common/logger.py | 22 +++---- src/common/server.py | 18 +++++- src/env/__init__.py | 5 ++ src/env/env.py | 64 +++++++++++++++++++ src/individuality/offline_llm.py | 6 +- src/individuality/per_bf_gen.py | 10 --- .../memory_system/manually_alter_memory.py | 11 ---- src/plugins/memory_system/offline_llm.py | 6 +- src/plugins/message/api.py | 3 +- src/plugins/models/utils_model.py | 6 +- src/plugins/zhishi/knowledge_library.py | 10 +-- 16 files changed, 128 insertions(+), 72 deletions(-) create mode 100644 src/common/common.py create mode 100644 src/env/__init__.py create mode 100644 src/env/env.py diff --git a/.gitignore b/.gitignore index 1c3d7bd1..6a6bb683 100644 --- a/.gitignore +++ b/.gitignore @@ -295,3 +295,5 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk + +!src/env \ No newline at end of file diff --git a/src/api/config_api.py b/src/api/config_api.py index e3934617..44883c87 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -145,11 +145,3 @@ class BotConfig: api_urls: Dict[str, str] # API URLs - -@strawberry.type -class EnvConfig: - pass - - @strawberry.field - def get_env(self) -> str: - return "env" diff --git a/src/common/__init__.py b/src/common/__init__.py index 497b4a41..94aca859 100644 --- a/src/common/__init__.py +++ b/src/common/__init__.py @@ -1 +1,6 @@ # 这个文件可以为空,但必须存在 +from .common import * + +__all__ = [ + "BASE_PATH" +] diff --git a/src/common/common.py b/src/common/common.py new file mode 100644 index 00000000..8701f551 --- /dev/null +++ b/src/common/common.py @@ -0,0 +1,7 @@ +from pathlib import Path + +BASE_PATH = Path(__file__).parent.parent.parent +"""项目根目录""" + +TEMPLATE_PATH = BASE_PATH / "template" +"""模板文件夹""" \ No newline at end of file diff --git a/src/common/database.py b/src/common/database.py index ee0ead0b..d35376d3 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,19 +1,18 @@ -import os from pymongo import MongoClient from pymongo.database import Database - +from src.env import env _client = None _db = None def __create_database_instance(): - uri = os.getenv("MONGODB_URI") - host = os.getenv("MONGODB_HOST", "127.0.0.1") - port = int(os.getenv("MONGODB_PORT", "27017")) + uri = env.getenv("MONGODB_URI") + host = env.getenv("MONGODB_HOST", "127.0.0.1") + port = int( env.getenv("MONGODB_PORT", "27017")) # db_name 变量在创建连接时不需要,在获取数据库实例时才使用 - username = os.getenv("MONGODB_USERNAME") - password = os.getenv("MONGODB_PASSWORD") - auth_source = os.getenv("MONGODB_AUTH_SOURCE") + username = env.getenv("MONGODB_USERNAME") + password = env.getenv("MONGODB_PASSWORD") + auth_source = env.getenv("MONGODB_AUTH_SOURCE") if uri: # 支持标准mongodb://和mongodb+srv://连接字符串 @@ -39,7 +38,7 @@ def get_db(): global _client, _db if _client is None: _client = __create_database_instance() - _db = _client[os.getenv("DATABASE_NAME", "MegBot")] + _db = _client[ env.getenv("DATABASE_NAME", "MegBot")] return _db diff --git a/src/common/logger.py b/src/common/logger.py index 8a5b7ffc..068073f3 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,15 +1,11 @@ from loguru import logger from typing import Dict, Optional, Union, List, Tuple import sys -import os from types import ModuleType from pathlib import Path -from dotenv import load_dotenv +from src.env import env # from ..plugins.chat.config import global_config -# 加载 .env 文件 -env_path = Path(__file__).resolve().parent.parent.parent / ".env" -load_dotenv(dotenv_path=env_path) # 保存原生处理器ID default_handler_id = None @@ -32,7 +28,7 @@ _custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义 current_file_path = Path(__file__).resolve() LOG_ROOT = "logs" -SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower() +SIMPLE_OUTPUT = env.getenv("SIMPLE_OUTPUT", "false").strip().lower() if SIMPLE_OUTPUT == "true": SIMPLE_OUTPUT = True else: @@ -625,7 +621,7 @@ def get_module_logger( # 控制台处理器 console_id = logger.add( sink=sys.stderr, - level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]), + level= env.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]), format=current_config["console_format"], filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"], enqueue=True, @@ -640,7 +636,7 @@ def get_module_logger( file_id = logger.add( sink=str(log_file), - level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]), + level= env.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]), format=current_config["file_format"], rotation=current_config["rotation"], retention=current_config["retention"], @@ -686,7 +682,7 @@ def add_custom_style_handler( try: custom_console_id = logger.add( sink=sys.stderr, - level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level), + level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level), format=console_format, filter=lambda record: record["extra"].get("module") == module_name and record["extra"].get("custom_style") == style_name, @@ -710,7 +706,7 @@ def add_custom_style_handler( # try: # custom_file_id = logger.add( # sink=str(log_file), - # level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level), + # level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level), # format=file_format, # rotation=current_config["rotation"], # retention=current_config["retention"], @@ -753,10 +749,10 @@ def remove_module_logger(module_name: str) -> None: # 添加全局默认处理器(只处理未注册模块的日志--->控制台) -# print(os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS")) +# print( env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS")) DEFAULT_GLOBAL_HANDLER = logger.add( sink=sys.stderr, - level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"), + level= env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"), format=( "{time:YYYY-MM-DD HH:mm:ss} | " "{level: <8} | " @@ -775,7 +771,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True) DEFAULT_FILE_HANDLER = logger.add( sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"), - level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"), + level= env.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"), format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}", rotation=DEFAULT_CONFIG["rotation"], retention=DEFAULT_CONFIG["retention"], diff --git a/src/common/server.py b/src/common/server.py index 51799629..b1b284d9 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -2,10 +2,16 @@ from fastapi import FastAPI, APIRouter from typing import Optional from uvicorn import Config, Server as UvicornServer import os +from src.env import env class Server: - def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + def __init__( + self, + host: Optional[str] = None, + port: Optional[int] = None, + app_name: str = "MaiMCore", + ): self.app = FastAPI(title=app_name) self._host: str = "127.0.0.1" self._port: int = 8080 @@ -46,7 +52,13 @@ class Server: async def run(self): """启动服务器""" # 禁用 uvicorn 默认日志和访问日志 - config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False) + config = Config( + app=self.app, + host=self._host, + port=self._port, + log_config=None, + access_log=False, + ) self._server = UvicornServer(config=config) try: await self._server.serve() @@ -71,4 +83,4 @@ class Server: return self.app -global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_server = Server(host=env.getenv("HOST"), port=int(env.getenv("PORT"))) diff --git a/src/env/__init__.py b/src/env/__init__.py new file mode 100644 index 00000000..d8c84834 --- /dev/null +++ b/src/env/__init__.py @@ -0,0 +1,5 @@ +from .env import * + +__all__ = [ + "env", +] \ No newline at end of file diff --git a/src/env/env.py b/src/env/env.py new file mode 100644 index 00000000..5308c589 --- /dev/null +++ b/src/env/env.py @@ -0,0 +1,64 @@ +import os +import shutil +from dotenv import load_dotenv + +from common.common import BASE_PATH, TEMPLATE_PATH + + +class EnvConfig: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(EnvConfig, cls).__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + self._initialized = True + self._load_env() + + def _init_env(self): + # 检测.env文件是否存在 + env_file = BASE_PATH / ".env" + if not env_file.exists(): + print("检测到.env文件不存在") + TEMPLATE_ENV_PATH = TEMPLATE_PATH / "template.env" + ENV_PATH = BASE_PATH / ".env" + shutil.copy(TEMPLATE_ENV_PATH, ENV_PATH) + print(f"已从{TEMPLATE_ENV_PATH}复制创建{ENV_PATH},请修改配置后重新启动") + + def _load_env(self): + self._init_env() + env_file = BASE_PATH / ".env" + load_dotenv(env_file) + + # 根据ENVIRONMENT变量加载对应的环境文件 + env_type = os.getenv("ENVIRONMENT", "prod") + if env_type == "dev": + env_file = BASE_PATH / ".env.dev" + elif env_type == "prod": + env_file = BASE_PATH / ".env" + + if env_file.exists(): + load_dotenv(env_file, override=True) + + def get(self, key, default=None): + """获取环境变量""" + return os.getenv(key, default) + + def getenv(self, key, default=None): + """获取环境变量""" + return self.get(key=key, default=default) + + def get_all(self): + """获取所有环境变量""" + return dict(os.environ) + + +# 创建全局实例 +env = EnvConfig() +"""环境变量管理器""" diff --git a/src/individuality/offline_llm.py b/src/individuality/offline_llm.py index 2b5b6dc2..dd7bae84 100644 --- a/src/individuality/offline_llm.py +++ b/src/individuality/offline_llm.py @@ -1,11 +1,11 @@ import asyncio -import os import time from typing import Tuple, Union import aiohttp import requests from src.common.logger import get_module_logger +from src.env import env logger = get_module_logger("offline_llm") @@ -14,8 +14,8 @@ class LLMRequestOff: def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name self.params = kwargs - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") + self.api_key = env.getenv("SILICONFLOW_KEY") + self.base_url = env.getenv("SILICONFLOW_BASE_URL") if not self.api_key or not self.base_url: raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") diff --git a/src/individuality/per_bf_gen.py b/src/individuality/per_bf_gen.py index 7e630bdd..73bee757 100644 --- a/src/individuality/per_bf_gen.py +++ b/src/individuality/per_bf_gen.py @@ -1,7 +1,6 @@ from typing import Dict, List import json import os -from dotenv import load_dotenv import sys import toml import random @@ -21,15 +20,6 @@ from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # n from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402 from src.individuality.offline_llm import LLMRequestOff # noqa E402 -# 加载环境变量 -env_path = os.path.join(root_path, ".env") -if os.path.exists(env_path): - print(f"从 {env_path} 加载环境变量") - load_dotenv(env_path) -else: - print(f"未找到环境变量文件: {env_path}") - print("将使用默认配置") - def adapt_scene(scene: str) -> str: personality_core = config["personality"]["personality_core"] diff --git a/src/plugins/memory_system/manually_alter_memory.py b/src/plugins/memory_system/manually_alter_memory.py index 1452d3d5..5e33d762 100644 --- a/src/plugins/memory_system/manually_alter_memory.py +++ b/src/plugins/memory_system/manually_alter_memory.py @@ -8,8 +8,6 @@ from rich.console import Console from Hippocampus import Hippocampus # 海马体和记忆图 -from dotenv import load_dotenv - """ 我想 总有那么一个瞬间 @@ -36,15 +34,6 @@ from src.common.database import db # noqa E402 logger = get_module_logger("mem_alter") console = Console() -# 加载环境变量 -if env_path.exists(): - logger.info(f"从 {env_path} 加载环境变量") - load_dotenv(env_path) -else: - logger.warning(f"未找到环境变量文件: {env_path}") - logger.info("将使用默认配置") - - # 查询节点信息 def query_mem_info(hippocampus: Hippocampus): while True: diff --git a/src/plugins/memory_system/offline_llm.py b/src/plugins/memory_system/offline_llm.py index fc50b17b..7cf92739 100644 --- a/src/plugins/memory_system/offline_llm.py +++ b/src/plugins/memory_system/offline_llm.py @@ -1,11 +1,11 @@ import asyncio -import os import time from typing import Tuple, Union import aiohttp import requests from src.common.logger import get_module_logger +from src.env import env logger = get_module_logger("offline_llm") @@ -14,8 +14,8 @@ class LLMRequestOff: def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name self.params = kwargs - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") + self.api_key = env.getenv("SILICONFLOW_KEY") + self.base_url = env.getenv("SILICONFLOW_BASE_URL") if not self.api_key or not self.base_url: raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index fb51539e..b5a4e0a3 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -3,6 +3,7 @@ from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase from src.common.server import global_server +from src.env import env import aiohttp import asyncio import uvicorn @@ -247,4 +248,4 @@ class MessageServer(BaseMessageHandler): raise e -global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) +global_api = MessageServer(host= env.getenv("HOST"), port=int( env.getenv("PORT")), app=global_server.get_app()) diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 2cab7b62..4dc60db1 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -14,7 +14,7 @@ import io import os from ...common.database import db from ...config.config import global_config - +from src.env import env logger = get_module_logger("model_utils") @@ -86,8 +86,8 @@ class LLMRequest: def __init__(self, model: dict, **kwargs): # 将大写的配置键转换为小写并从config中获取实际值 try: - self.api_key = os.environ[model["key"]] - self.base_url = os.environ[model["base_url"]] + self.api_key = env.getenv(model["key"]) + self.base_url = env.getenv(model["base_url"]) except AttributeError as e: logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py index f8914c2f..e0fead41 100644 --- a/src/plugins/zhishi/knowledge_library.py +++ b/src/plugins/zhishi/knowledge_library.py @@ -1,7 +1,6 @@ import os import sys import requests -from dotenv import load_dotenv import hashlib from datetime import datetime from tqdm import tqdm @@ -14,19 +13,14 @@ sys.path.append(root_path) # 现在可以导入src模块 from src.common.database import db # noqa E402 - -# 加载根目录下的env.edv文件 -env_path = os.path.join(root_path, ".env") -if not os.path.exists(env_path): - raise FileNotFoundError(f"配置文件不存在: {env_path}") -load_dotenv(env_path) +from src.env import env class KnowledgeLibrary: def __init__(self): self.raw_info_dir = "data/raw_info" self._ensure_dirs() - self.api_key = os.getenv("SILICONFLOW_KEY") + self.api_key = env.getenv("SILICONFLOW_KEY") if not self.api_key: raise ValueError("SILICONFLOW_API_KEY 环境变量未设置") self.console = Console()