fix: 统一管理环境变量,避免环境变量未导入前使用环境变量以及重复导入.env

pull/840/head
Night-stars-1 2025-04-24 22:50:50 +08:00
parent 4f6ef7b0a7
commit 26339453fe
16 changed files with 128 additions and 72 deletions

2
.gitignore vendored
View File

@ -295,3 +295,5 @@ $RECYCLE.BIN/
# Windows shortcuts # Windows shortcuts
*.lnk *.lnk
!src/env

View File

@ -145,11 +145,3 @@ class BotConfig:
api_urls: Dict[str, str] # API URLs api_urls: Dict[str, str] # API URLs
@strawberry.type
class EnvConfig:
pass
@strawberry.field
def get_env(self) -> str:
return "env"

View File

@ -1 +1,6 @@
# 这个文件可以为空,但必须存在 # 这个文件可以为空,但必须存在
from .common import *
__all__ = [
"BASE_PATH"
]

View File

@ -0,0 +1,7 @@
from pathlib import Path
BASE_PATH = Path(__file__).parent.parent.parent
"""项目根目录"""
TEMPLATE_PATH = BASE_PATH / "template"
"""模板文件夹"""

View File

@ -1,19 +1,18 @@
import os
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.database import Database from pymongo.database import Database
from src.env import env
_client = None _client = None
_db = None _db = None
def __create_database_instance(): def __create_database_instance():
uri = os.getenv("MONGODB_URI") uri = env.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1") host = env.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017")) port = int( env.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用 # db_name 变量在创建连接时不需要,在获取数据库实例时才使用
username = os.getenv("MONGODB_USERNAME") username = env.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD") password = env.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE") auth_source = env.getenv("MONGODB_AUTH_SOURCE")
if uri: if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串 # 支持标准mongodb://和mongodb+srv://连接字符串
@ -39,7 +38,7 @@ def get_db():
global _client, _db global _client, _db
if _client is None: if _client is None:
_client = __create_database_instance() _client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")] _db = _client[ env.getenv("DATABASE_NAME", "MegBot")]
return _db return _db

View File

@ -1,15 +1,11 @@
from loguru import logger from loguru import logger
from typing import Dict, Optional, Union, List, Tuple from typing import Dict, Optional, Union, List, Tuple
import sys import sys
import os
from types import ModuleType from types import ModuleType
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from src.env import env
# from ..plugins.chat.config import global_config # from ..plugins.chat.config import global_config
# 加载 .env 文件
env_path = Path(__file__).resolve().parent.parent.parent / ".env"
load_dotenv(dotenv_path=env_path)
# 保存原生处理器ID # 保存原生处理器ID
default_handler_id = None default_handler_id = None
@ -32,7 +28,7 @@ _custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义
current_file_path = Path(__file__).resolve() current_file_path = Path(__file__).resolve()
LOG_ROOT = "logs" LOG_ROOT = "logs"
SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower() SIMPLE_OUTPUT = env.getenv("SIMPLE_OUTPUT", "false").strip().lower()
if SIMPLE_OUTPUT == "true": if SIMPLE_OUTPUT == "true":
SIMPLE_OUTPUT = True SIMPLE_OUTPUT = True
else: else:
@ -625,7 +621,7 @@ def get_module_logger(
# 控制台处理器 # 控制台处理器
console_id = logger.add( console_id = logger.add(
sink=sys.stderr, sink=sys.stderr,
level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]), level= env.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
format=current_config["console_format"], format=current_config["console_format"],
filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"], filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"],
enqueue=True, enqueue=True,
@ -640,7 +636,7 @@ def get_module_logger(
file_id = logger.add( file_id = logger.add(
sink=str(log_file), sink=str(log_file),
level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]), level= env.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
format=current_config["file_format"], format=current_config["file_format"],
rotation=current_config["rotation"], rotation=current_config["rotation"],
retention=current_config["retention"], retention=current_config["retention"],
@ -686,7 +682,7 @@ def add_custom_style_handler(
try: try:
custom_console_id = logger.add( custom_console_id = logger.add(
sink=sys.stderr, sink=sys.stderr,
level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level), level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
format=console_format, format=console_format,
filter=lambda record: record["extra"].get("module") == module_name filter=lambda record: record["extra"].get("module") == module_name
and record["extra"].get("custom_style") == style_name, and record["extra"].get("custom_style") == style_name,
@ -710,7 +706,7 @@ def add_custom_style_handler(
# try: # try:
# custom_file_id = logger.add( # custom_file_id = logger.add(
# sink=str(log_file), # sink=str(log_file),
# level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level), # level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
# format=file_format, # format=file_format,
# rotation=current_config["rotation"], # rotation=current_config["rotation"],
# retention=current_config["retention"], # retention=current_config["retention"],
@ -753,10 +749,10 @@ def remove_module_logger(module_name: str) -> None:
# 添加全局默认处理器(只处理未注册模块的日志--->控制台) # 添加全局默认处理器(只处理未注册模块的日志--->控制台)
# print(os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS")) # print( env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
DEFAULT_GLOBAL_HANDLER = logger.add( DEFAULT_GLOBAL_HANDLER = logger.add(
sink=sys.stderr, sink=sys.stderr,
level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"), level= env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
format=( format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | " "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | " "<level>{level: <8}</level> | "
@ -775,7 +771,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
DEFAULT_FILE_HANDLER = logger.add( DEFAULT_FILE_HANDLER = logger.add(
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"), sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"), level= env.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}", format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
rotation=DEFAULT_CONFIG["rotation"], rotation=DEFAULT_CONFIG["rotation"],
retention=DEFAULT_CONFIG["retention"], retention=DEFAULT_CONFIG["retention"],

View File

@ -2,10 +2,16 @@ from fastapi import FastAPI, APIRouter
from typing import Optional from typing import Optional
from uvicorn import Config, Server as UvicornServer from uvicorn import Config, Server as UvicornServer
import os import os
from src.env import env
class Server: class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
app_name: str = "MaiMCore",
):
self.app = FastAPI(title=app_name) self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1" self._host: str = "127.0.0.1"
self._port: int = 8080 self._port: int = 8080
@ -46,7 +52,13 @@ class Server:
async def run(self): async def run(self):
"""启动服务器""" """启动服务器"""
# 禁用 uvicorn 默认日志和访问日志 # 禁用 uvicorn 默认日志和访问日志
config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False) config = Config(
app=self.app,
host=self._host,
port=self._port,
log_config=None,
access_log=False,
)
self._server = UvicornServer(config=config) self._server = UvicornServer(config=config)
try: try:
await self._server.serve() await self._server.serve()
@ -71,4 +83,4 @@ class Server:
return self.app return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) global_server = Server(host=env.getenv("HOST"), port=int(env.getenv("PORT")))

5
src/env/__init__.py vendored 100644
View File

@ -0,0 +1,5 @@
from .env import *
__all__ = [
"env",
]

64
src/env/env.py vendored 100644
View File

@ -0,0 +1,64 @@
import os
import shutil
from dotenv import load_dotenv
from common.common import BASE_PATH, TEMPLATE_PATH
class EnvConfig:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(EnvConfig, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._load_env()
def _init_env(self):
# 检测.env文件是否存在
env_file = BASE_PATH / ".env"
if not env_file.exists():
print("检测到.env文件不存在")
TEMPLATE_ENV_PATH = TEMPLATE_PATH / "template.env"
ENV_PATH = BASE_PATH / ".env"
shutil.copy(TEMPLATE_ENV_PATH, ENV_PATH)
print(f"已从{TEMPLATE_ENV_PATH}复制创建{ENV_PATH},请修改配置后重新启动")
def _load_env(self):
self._init_env()
env_file = BASE_PATH / ".env"
load_dotenv(env_file)
# 根据ENVIRONMENT变量加载对应的环境文件
env_type = os.getenv("ENVIRONMENT", "prod")
if env_type == "dev":
env_file = BASE_PATH / ".env.dev"
elif env_type == "prod":
env_file = BASE_PATH / ".env"
if env_file.exists():
load_dotenv(env_file, override=True)
def get(self, key, default=None):
"""获取环境变量"""
return os.getenv(key, default)
def getenv(self, key, default=None):
"""获取环境变量"""
return self.get(key=key, default=default)
def get_all(self):
"""获取所有环境变量"""
return dict(os.environ)
# 创建全局实例
env = EnvConfig()
"""环境变量管理器"""

View File

@ -1,11 +1,11 @@
import asyncio import asyncio
import os
import time import time
from typing import Tuple, Union from typing import Tuple, Union
import aiohttp import aiohttp
import requests import requests
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.env import env
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")
@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs): def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name self.model_name = model_name
self.params = kwargs self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY") self.api_key = env.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL") self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url: if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@ -1,7 +1,6 @@
from typing import Dict, List from typing import Dict, List
import json import json
import os import os
from dotenv import load_dotenv
import sys import sys
import toml import toml
import random import random
@ -21,15 +20,6 @@ from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # n
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402 from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
from src.individuality.offline_llm import LLMRequestOff # noqa E402 from src.individuality.offline_llm import LLMRequestOff # noqa E402
# 加载环境变量
env_path = os.path.join(root_path, ".env")
if os.path.exists(env_path):
print(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
print(f"未找到环境变量文件: {env_path}")
print("将使用默认配置")
def adapt_scene(scene: str) -> str: def adapt_scene(scene: str) -> str:
personality_core = config["personality"]["personality_core"] personality_core = config["personality"]["personality_core"]

View File

@ -8,8 +8,6 @@ from rich.console import Console
from Hippocampus import Hippocampus # 海马体和记忆图 from Hippocampus import Hippocampus # 海马体和记忆图
from dotenv import load_dotenv
""" """
我想 总有那么一个瞬间 我想 总有那么一个瞬间
@ -36,15 +34,6 @@ from src.common.database import db # noqa E402
logger = get_module_logger("mem_alter") logger = get_module_logger("mem_alter")
console = Console() console = Console()
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
# 查询节点信息 # 查询节点信息
def query_mem_info(hippocampus: Hippocampus): def query_mem_info(hippocampus: Hippocampus):
while True: while True:

View File

@ -1,11 +1,11 @@
import asyncio import asyncio
import os
import time import time
from typing import Tuple, Union from typing import Tuple, Union
import aiohttp import aiohttp
import requests import requests
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.env import env
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")
@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name self.model_name = model_name
self.params = kwargs self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY") self.api_key = env.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL") self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url: if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@ -3,6 +3,7 @@ from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase from src.plugins.message.message_base import MessageBase
from src.common.server import global_server from src.common.server import global_server
from src.env import env
import aiohttp import aiohttp
import asyncio import asyncio
import uvicorn import uvicorn
@ -247,4 +248,4 @@ class MessageServer(BaseMessageHandler):
raise e raise e
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) global_api = MessageServer(host= env.getenv("HOST"), port=int( env.getenv("PORT")), app=global_server.get_app())

View File

@ -14,7 +14,7 @@ import io
import os import os
from ...common.database import db from ...common.database import db
from ...config.config import global_config from ...config.config import global_config
from src.env import env
logger = get_module_logger("model_utils") logger = get_module_logger("model_utils")
@ -86,8 +86,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs): def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值 # 将大写的配置键转换为小写并从config中获取实际值
try: try:
self.api_key = os.environ[model["key"]] self.api_key = env.getenv(model["key"])
self.base_url = os.environ[model["base_url"]] self.base_url = env.getenv(model["base_url"])
except AttributeError as e: except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")

View File

@ -1,7 +1,6 @@
import os import os
import sys import sys
import requests import requests
from dotenv import load_dotenv
import hashlib import hashlib
from datetime import datetime from datetime import datetime
from tqdm import tqdm from tqdm import tqdm
@ -14,19 +13,14 @@ sys.path.append(root_path)
# 现在可以导入src模块 # 现在可以导入src模块
from src.common.database import db # noqa E402 from src.common.database import db # noqa E402
from src.env import env
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary: class KnowledgeLibrary:
def __init__(self): def __init__(self):
self.raw_info_dir = "data/raw_info" self.raw_info_dir = "data/raw_info"
self._ensure_dirs() self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY") self.api_key = env.getenv("SILICONFLOW_KEY")
if not self.api_key: if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置") raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console() self.console = Console()