diff --git a/.gitignore b/.gitignore
index 1c3d7bd1..6a6bb683 100644
--- a/.gitignore
+++ b/.gitignore
@@ -295,3 +295,5 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk
+
+!src/env
\ No newline at end of file
diff --git a/src/api/config_api.py b/src/api/config_api.py
index e3934617..44883c87 100644
--- a/src/api/config_api.py
+++ b/src/api/config_api.py
@@ -145,11 +145,3 @@ class BotConfig:
api_urls: Dict[str, str] # API URLs
-
-@strawberry.type
-class EnvConfig:
- pass
-
- @strawberry.field
- def get_env(self) -> str:
- return "env"
diff --git a/src/common/__init__.py b/src/common/__init__.py
index 497b4a41..94aca859 100644
--- a/src/common/__init__.py
+++ b/src/common/__init__.py
@@ -1 +1,6 @@
# 这个文件可以为空,但必须存在
+from .common import *
+
+__all__ = [
+ "BASE_PATH"
+]
diff --git a/src/common/common.py b/src/common/common.py
new file mode 100644
index 00000000..8701f551
--- /dev/null
+++ b/src/common/common.py
@@ -0,0 +1,7 @@
+from pathlib import Path
+
+BASE_PATH = Path(__file__).parent.parent.parent
+"""项目根目录"""
+
+TEMPLATE_PATH = BASE_PATH / "template"
+"""模板文件夹"""
\ No newline at end of file
diff --git a/src/common/database.py b/src/common/database.py
index ee0ead0b..d35376d3 100644
--- a/src/common/database.py
+++ b/src/common/database.py
@@ -1,19 +1,18 @@
-import os
from pymongo import MongoClient
from pymongo.database import Database
-
+from src.env import env
_client = None
_db = None
def __create_database_instance():
- uri = os.getenv("MONGODB_URI")
- host = os.getenv("MONGODB_HOST", "127.0.0.1")
- port = int(os.getenv("MONGODB_PORT", "27017"))
+ uri = env.getenv("MONGODB_URI")
+ host = env.getenv("MONGODB_HOST", "127.0.0.1")
+ port = int( env.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
- username = os.getenv("MONGODB_USERNAME")
- password = os.getenv("MONGODB_PASSWORD")
- auth_source = os.getenv("MONGODB_AUTH_SOURCE")
+ username = env.getenv("MONGODB_USERNAME")
+ password = env.getenv("MONGODB_PASSWORD")
+ auth_source = env.getenv("MONGODB_AUTH_SOURCE")
if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串
@@ -39,7 +38,7 @@ def get_db():
global _client, _db
if _client is None:
_client = __create_database_instance()
- _db = _client[os.getenv("DATABASE_NAME", "MegBot")]
+ _db = _client[ env.getenv("DATABASE_NAME", "MegBot")]
return _db
diff --git a/src/common/logger.py b/src/common/logger.py
index 8a5b7ffc..068073f3 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -1,15 +1,11 @@
from loguru import logger
from typing import Dict, Optional, Union, List, Tuple
import sys
-import os
from types import ModuleType
from pathlib import Path
-from dotenv import load_dotenv
+from src.env import env
# from ..plugins.chat.config import global_config
-# 加载 .env 文件
-env_path = Path(__file__).resolve().parent.parent.parent / ".env"
-load_dotenv(dotenv_path=env_path)
# 保存原生处理器ID
default_handler_id = None
@@ -32,7 +28,7 @@ _custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义
current_file_path = Path(__file__).resolve()
LOG_ROOT = "logs"
-SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower()
+SIMPLE_OUTPUT = env.getenv("SIMPLE_OUTPUT", "false").strip().lower()
if SIMPLE_OUTPUT == "true":
SIMPLE_OUTPUT = True
else:
@@ -625,7 +621,7 @@ def get_module_logger(
# 控制台处理器
console_id = logger.add(
sink=sys.stderr,
- level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
+ level= env.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
format=current_config["console_format"],
filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"],
enqueue=True,
@@ -640,7 +636,7 @@ def get_module_logger(
file_id = logger.add(
sink=str(log_file),
- level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
+ level= env.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
format=current_config["file_format"],
rotation=current_config["rotation"],
retention=current_config["retention"],
@@ -686,7 +682,7 @@ def add_custom_style_handler(
try:
custom_console_id = logger.add(
sink=sys.stderr,
- level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
+ level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
format=console_format,
filter=lambda record: record["extra"].get("module") == module_name
and record["extra"].get("custom_style") == style_name,
@@ -710,7 +706,7 @@ def add_custom_style_handler(
# try:
# custom_file_id = logger.add(
# sink=str(log_file),
- # level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
+ # level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
# format=file_format,
# rotation=current_config["rotation"],
# retention=current_config["retention"],
@@ -753,10 +749,10 @@ def remove_module_logger(module_name: str) -> None:
# 添加全局默认处理器(只处理未注册模块的日志--->控制台)
-# print(os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
+# print( env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
DEFAULT_GLOBAL_HANDLER = logger.add(
sink=sys.stderr,
- level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
+ level= env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
format=(
"{time:YYYY-MM-DD HH:mm:ss} | "
"{level: <8} | "
@@ -775,7 +771,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
DEFAULT_FILE_HANDLER = logger.add(
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
- level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
+ level= env.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
rotation=DEFAULT_CONFIG["rotation"],
retention=DEFAULT_CONFIG["retention"],
diff --git a/src/common/server.py b/src/common/server.py
index 51799629..b1b284d9 100644
--- a/src/common/server.py
+++ b/src/common/server.py
@@ -2,10 +2,16 @@ from fastapi import FastAPI, APIRouter
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import os
+from src.env import env
class Server:
- def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
+ def __init__(
+ self,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ app_name: str = "MaiMCore",
+ ):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
@@ -46,7 +52,13 @@ class Server:
async def run(self):
"""启动服务器"""
# 禁用 uvicorn 默认日志和访问日志
- config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
+ config = Config(
+ app=self.app,
+ host=self._host,
+ port=self._port,
+ log_config=None,
+ access_log=False,
+ )
self._server = UvicornServer(config=config)
try:
await self._server.serve()
@@ -71,4 +83,4 @@ class Server:
return self.app
-global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
+global_server = Server(host=env.getenv("HOST"), port=int(env.getenv("PORT")))
diff --git a/src/env/__init__.py b/src/env/__init__.py
new file mode 100644
index 00000000..d8c84834
--- /dev/null
+++ b/src/env/__init__.py
@@ -0,0 +1,5 @@
+from .env import *
+
+__all__ = [
+ "env",
+]
\ No newline at end of file
diff --git a/src/env/env.py b/src/env/env.py
new file mode 100644
index 00000000..5308c589
--- /dev/null
+++ b/src/env/env.py
@@ -0,0 +1,64 @@
+import os
+import shutil
+from dotenv import load_dotenv
+
+from common.common import BASE_PATH, TEMPLATE_PATH
+
+
+class EnvConfig:
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(EnvConfig, cls).__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ self._initialized = True
+ self._load_env()
+
+ def _init_env(self):
+ # 检测.env文件是否存在
+ env_file = BASE_PATH / ".env"
+ if not env_file.exists():
+ print("检测到.env文件不存在")
+ TEMPLATE_ENV_PATH = TEMPLATE_PATH / "template.env"
+ ENV_PATH = BASE_PATH / ".env"
+ shutil.copy(TEMPLATE_ENV_PATH, ENV_PATH)
+ print(f"已从{TEMPLATE_ENV_PATH}复制创建{ENV_PATH},请修改配置后重新启动")
+
+ def _load_env(self):
+ self._init_env()
+ env_file = BASE_PATH / ".env"
+ load_dotenv(env_file)
+
+ # 根据ENVIRONMENT变量加载对应的环境文件
+ env_type = os.getenv("ENVIRONMENT", "prod")
+ if env_type == "dev":
+ env_file = BASE_PATH / ".env.dev"
+ elif env_type == "prod":
+ env_file = BASE_PATH / ".env"
+
+ if env_file.exists():
+ load_dotenv(env_file, override=True)
+
+ def get(self, key, default=None):
+ """获取环境变量"""
+ return os.getenv(key, default)
+
+ def getenv(self, key, default=None):
+ """获取环境变量"""
+ return self.get(key=key, default=default)
+
+ def get_all(self):
+ """获取所有环境变量"""
+ return dict(os.environ)
+
+
+# 创建全局实例
+env = EnvConfig()
+"""环境变量管理器"""
diff --git a/src/individuality/offline_llm.py b/src/individuality/offline_llm.py
index 2b5b6dc2..dd7bae84 100644
--- a/src/individuality/offline_llm.py
+++ b/src/individuality/offline_llm.py
@@ -1,11 +1,11 @@
import asyncio
-import os
import time
from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
+from src.env import env
logger = get_module_logger("offline_llm")
@@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+ self.api_key = env.getenv("SILICONFLOW_KEY")
+ self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
diff --git a/src/individuality/per_bf_gen.py b/src/individuality/per_bf_gen.py
index 7e630bdd..73bee757 100644
--- a/src/individuality/per_bf_gen.py
+++ b/src/individuality/per_bf_gen.py
@@ -1,7 +1,6 @@
from typing import Dict, List
import json
import os
-from dotenv import load_dotenv
import sys
import toml
import random
@@ -21,15 +20,6 @@ from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # n
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
from src.individuality.offline_llm import LLMRequestOff # noqa E402
-# 加载环境变量
-env_path = os.path.join(root_path, ".env")
-if os.path.exists(env_path):
- print(f"从 {env_path} 加载环境变量")
- load_dotenv(env_path)
-else:
- print(f"未找到环境变量文件: {env_path}")
- print("将使用默认配置")
-
def adapt_scene(scene: str) -> str:
personality_core = config["personality"]["personality_core"]
diff --git a/src/plugins/memory_system/manually_alter_memory.py b/src/plugins/memory_system/manually_alter_memory.py
index 1452d3d5..5e33d762 100644
--- a/src/plugins/memory_system/manually_alter_memory.py
+++ b/src/plugins/memory_system/manually_alter_memory.py
@@ -8,8 +8,6 @@ from rich.console import Console
from Hippocampus import Hippocampus # 海马体和记忆图
-from dotenv import load_dotenv
-
"""
我想 总有那么一个瞬间
@@ -36,15 +34,6 @@ from src.common.database import db # noqa E402
logger = get_module_logger("mem_alter")
console = Console()
-# 加载环境变量
-if env_path.exists():
- logger.info(f"从 {env_path} 加载环境变量")
- load_dotenv(env_path)
-else:
- logger.warning(f"未找到环境变量文件: {env_path}")
- logger.info("将使用默认配置")
-
-
# 查询节点信息
def query_mem_info(hippocampus: Hippocampus):
while True:
diff --git a/src/plugins/memory_system/offline_llm.py b/src/plugins/memory_system/offline_llm.py
index fc50b17b..7cf92739 100644
--- a/src/plugins/memory_system/offline_llm.py
+++ b/src/plugins/memory_system/offline_llm.py
@@ -1,11 +1,11 @@
import asyncio
-import os
import time
from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
+from src.env import env
logger = get_module_logger("offline_llm")
@@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
- self.api_key = os.getenv("SILICONFLOW_KEY")
- self.base_url = os.getenv("SILICONFLOW_BASE_URL")
+ self.api_key = env.getenv("SILICONFLOW_KEY")
+ self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py
index fb51539e..b5a4e0a3 100644
--- a/src/plugins/message/api.py
+++ b/src/plugins/message/api.py
@@ -3,6 +3,7 @@ from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
+from src.env import env
import aiohttp
import asyncio
import uvicorn
@@ -247,4 +248,4 @@ class MessageServer(BaseMessageHandler):
raise e
-global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
+global_api = MessageServer(host= env.getenv("HOST"), port=int( env.getenv("PORT")), app=global_server.get_app())
diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 2cab7b62..4dc60db1 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -14,7 +14,7 @@ import io
import os
from ...common.database import db
from ...config.config import global_config
-
+from src.env import env
logger = get_module_logger("model_utils")
@@ -86,8 +86,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
- self.api_key = os.environ[model["key"]]
- self.base_url = os.environ[model["base_url"]]
+ self.api_key = env.getenv(model["key"])
+ self.base_url = env.getenv(model["base_url"])
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py
index f8914c2f..e0fead41 100644
--- a/src/plugins/zhishi/knowledge_library.py
+++ b/src/plugins/zhishi/knowledge_library.py
@@ -1,7 +1,6 @@
import os
import sys
import requests
-from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
@@ -14,19 +13,14 @@ sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import db # noqa E402
-
-# 加载根目录下的env.edv文件
-env_path = os.path.join(root_path, ".env")
-if not os.path.exists(env_path):
- raise FileNotFoundError(f"配置文件不存在: {env_path}")
-load_dotenv(env_path)
+from src.env import env
class KnowledgeLibrary:
def __init__(self):
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
- self.api_key = os.getenv("SILICONFLOW_KEY")
+ self.api_key = env.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()