fix: 统一管理环境变量,避免环境变量未导入前使用环境变量以及重复导入.env

pull/840/head
Night-stars-1 2025-04-24 22:50:50 +08:00
parent 4f6ef7b0a7
commit 26339453fe
16 changed files with 128 additions and 72 deletions

2
.gitignore vendored
View File

@ -295,3 +295,5 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk
!src/env

View File

@ -145,11 +145,3 @@ class BotConfig:
api_urls: Dict[str, str] # API URLs
@strawberry.type
class EnvConfig:
pass
@strawberry.field
def get_env(self) -> str:
return "env"

View File

@ -1 +1,6 @@
# 这个文件可以为空,但必须存在
from .common import *
__all__ = [
"BASE_PATH"
]

View File

@ -0,0 +1,7 @@
from pathlib import Path
BASE_PATH = Path(__file__).parent.parent.parent
"""项目根目录"""
TEMPLATE_PATH = BASE_PATH / "template"
"""模板文件夹"""

View File

@ -1,19 +1,18 @@
import os
from pymongo import MongoClient
from pymongo.database import Database
from src.env import env
_client = None
_db = None
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
uri = env.getenv("MONGODB_URI")
host = env.getenv("MONGODB_HOST", "127.0.0.1")
port = int( env.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
username = env.getenv("MONGODB_USERNAME")
password = env.getenv("MONGODB_PASSWORD")
auth_source = env.getenv("MONGODB_AUTH_SOURCE")
if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串
@ -39,7 +38,7 @@ def get_db():
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
_db = _client[ env.getenv("DATABASE_NAME", "MegBot")]
return _db

View File

@ -1,15 +1,11 @@
from loguru import logger
from typing import Dict, Optional, Union, List, Tuple
import sys
import os
from types import ModuleType
from pathlib import Path
from dotenv import load_dotenv
from src.env import env
# from ..plugins.chat.config import global_config
# 加载 .env 文件
env_path = Path(__file__).resolve().parent.parent.parent / ".env"
load_dotenv(dotenv_path=env_path)
# 保存原生处理器ID
default_handler_id = None
@ -32,7 +28,7 @@ _custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义
current_file_path = Path(__file__).resolve()
LOG_ROOT = "logs"
SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower()
SIMPLE_OUTPUT = env.getenv("SIMPLE_OUTPUT", "false").strip().lower()
if SIMPLE_OUTPUT == "true":
SIMPLE_OUTPUT = True
else:
@ -625,7 +621,7 @@ def get_module_logger(
# 控制台处理器
console_id = logger.add(
sink=sys.stderr,
level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
level= env.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
format=current_config["console_format"],
filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"],
enqueue=True,
@ -640,7 +636,7 @@ def get_module_logger(
file_id = logger.add(
sink=str(log_file),
level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
level= env.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
format=current_config["file_format"],
rotation=current_config["rotation"],
retention=current_config["retention"],
@ -686,7 +682,7 @@ def add_custom_style_handler(
try:
custom_console_id = logger.add(
sink=sys.stderr,
level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
format=console_format,
filter=lambda record: record["extra"].get("module") == module_name
and record["extra"].get("custom_style") == style_name,
@ -710,7 +706,7 @@ def add_custom_style_handler(
# try:
# custom_file_id = logger.add(
# sink=str(log_file),
# level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
# level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
# format=file_format,
# rotation=current_config["rotation"],
# retention=current_config["retention"],
@ -753,10 +749,10 @@ def remove_module_logger(module_name: str) -> None:
# 添加全局默认处理器(只处理未注册模块的日志--->控制台)
# print(os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
# print( env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
DEFAULT_GLOBAL_HANDLER = logger.add(
sink=sys.stderr,
level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
level= env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
@ -775,7 +771,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
DEFAULT_FILE_HANDLER = logger.add(
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
level= env.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
rotation=DEFAULT_CONFIG["rotation"],
retention=DEFAULT_CONFIG["retention"],

View File

@ -2,10 +2,16 @@ from fastapi import FastAPI, APIRouter
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import os
from src.env import env
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
app_name: str = "MaiMCore",
):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
@ -46,7 +52,13 @@ class Server:
async def run(self):
"""启动服务器"""
# 禁用 uvicorn 默认日志和访问日志
config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
config = Config(
app=self.app,
host=self._host,
port=self._port,
log_config=None,
access_log=False,
)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
@ -71,4 +83,4 @@ class Server:
return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
global_server = Server(host=env.getenv("HOST"), port=int(env.getenv("PORT")))

5
src/env/__init__.py vendored 100644
View File

@ -0,0 +1,5 @@
from .env import *
__all__ = [
"env",
]

64
src/env/env.py vendored 100644
View File

@ -0,0 +1,64 @@
import os
import shutil
from dotenv import load_dotenv
from common.common import BASE_PATH, TEMPLATE_PATH
class EnvConfig:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(EnvConfig, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._load_env()
def _init_env(self):
# 检测.env文件是否存在
env_file = BASE_PATH / ".env"
if not env_file.exists():
print("检测到.env文件不存在")
TEMPLATE_ENV_PATH = TEMPLATE_PATH / "template.env"
ENV_PATH = BASE_PATH / ".env"
shutil.copy(TEMPLATE_ENV_PATH, ENV_PATH)
print(f"已从{TEMPLATE_ENV_PATH}复制创建{ENV_PATH},请修改配置后重新启动")
def _load_env(self):
self._init_env()
env_file = BASE_PATH / ".env"
load_dotenv(env_file)
# 根据ENVIRONMENT变量加载对应的环境文件
env_type = os.getenv("ENVIRONMENT", "prod")
if env_type == "dev":
env_file = BASE_PATH / ".env.dev"
elif env_type == "prod":
env_file = BASE_PATH / ".env"
if env_file.exists():
load_dotenv(env_file, override=True)
def get(self, key, default=None):
"""获取环境变量"""
return os.getenv(key, default)
def getenv(self, key, default=None):
"""获取环境变量"""
return self.get(key=key, default=default)
def get_all(self):
"""获取所有环境变量"""
return dict(os.environ)
# 创建全局实例
env = EnvConfig()
"""环境变量管理器"""

View File

@ -1,11 +1,11 @@
import asyncio
import os
import time
from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
from src.env import env
logger = get_module_logger("offline_llm")
@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = env.getenv("SILICONFLOW_KEY")
self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@ -1,7 +1,6 @@
from typing import Dict, List
import json
import os
from dotenv import load_dotenv
import sys
import toml
import random
@ -21,15 +20,6 @@ from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # n
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
from src.individuality.offline_llm import LLMRequestOff # noqa E402
# 加载环境变量
env_path = os.path.join(root_path, ".env")
if os.path.exists(env_path):
print(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
print(f"未找到环境变量文件: {env_path}")
print("将使用默认配置")
def adapt_scene(scene: str) -> str:
personality_core = config["personality"]["personality_core"]

View File

@ -8,8 +8,6 @@ from rich.console import Console
from Hippocampus import Hippocampus # 海马体和记忆图
from dotenv import load_dotenv
"""
我想 总有那么一个瞬间
@ -36,15 +34,6 @@ from src.common.database import db # noqa E402
logger = get_module_logger("mem_alter")
console = Console()
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
# 查询节点信息
def query_mem_info(hippocampus: Hippocampus):
while True:

View File

@ -1,11 +1,11 @@
import asyncio
import os
import time
from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
from src.env import env
logger = get_module_logger("offline_llm")
@ -14,8 +14,8 @@ class LLMRequestOff:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = env.getenv("SILICONFLOW_KEY")
self.base_url = env.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@ -3,6 +3,7 @@ from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
from src.env import env
import aiohttp
import asyncio
import uvicorn
@ -247,4 +248,4 @@ class MessageServer(BaseMessageHandler):
raise e
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
global_api = MessageServer(host= env.getenv("HOST"), port=int( env.getenv("PORT")), app=global_server.get_app())

View File

@ -14,7 +14,7 @@ import io
import os
from ...common.database import db
from ...config.config import global_config
from src.env import env
logger = get_module_logger("model_utils")
@ -86,8 +86,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = os.environ[model["key"]]
self.base_url = os.environ[model["base_url"]]
self.api_key = env.getenv(model["key"])
self.base_url = env.getenv(model["base_url"])
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")

View File

@ -1,7 +1,6 @@
import os
import sys
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
@ -14,19 +13,14 @@ sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import db # noqa E402
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
from src.env import env
class KnowledgeLibrary:
def __init__(self):
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
self.api_key = env.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()