mirror of https://github.com/Mai-with-u/MaiBot.git
fix: 统一管理环境变量,避免环境变量未导入前使用环境变量以及重复导入.env
parent
4f6ef7b0a7
commit
26339453fe
|
|
@ -295,3 +295,5 @@ $RECYCLE.BIN/
|
|||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
!src/env
|
||||
|
|
@ -145,11 +145,3 @@ class BotConfig:
|
|||
|
||||
api_urls: Dict[str, str] # API URLs
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class EnvConfig:
|
||||
pass
|
||||
|
||||
@strawberry.field
|
||||
def get_env(self) -> str:
|
||||
return "env"
|
||||
|
|
|
|||
|
|
@ -1 +1,6 @@
|
|||
# 这个文件可以为空,但必须存在
|
||||
from .common import *
|
||||
|
||||
__all__ = [
|
||||
"BASE_PATH"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
BASE_PATH = Path(__file__).parent.parent.parent
|
||||
"""项目根目录"""
|
||||
|
||||
TEMPLATE_PATH = BASE_PATH / "template"
|
||||
"""模板文件夹"""
|
||||
|
|
@ -1,19 +1,18 @@
|
|||
import os
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database
|
||||
|
||||
from src.env import env
|
||||
_client = None
|
||||
_db = None
|
||||
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("MONGODB_PORT", "27017"))
|
||||
uri = env.getenv("MONGODB_URI")
|
||||
host = env.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
port = int( env.getenv("MONGODB_PORT", "27017"))
|
||||
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
|
||||
username = os.getenv("MONGODB_USERNAME")
|
||||
password = os.getenv("MONGODB_PASSWORD")
|
||||
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
|
||||
username = env.getenv("MONGODB_USERNAME")
|
||||
password = env.getenv("MONGODB_PASSWORD")
|
||||
auth_source = env.getenv("MONGODB_AUTH_SOURCE")
|
||||
|
||||
if uri:
|
||||
# 支持标准mongodb://和mongodb+srv://连接字符串
|
||||
|
|
@ -39,7 +38,7 @@ def get_db():
|
|||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
|
||||
_db = _client[ env.getenv("DATABASE_NAME", "MegBot")]
|
||||
return _db
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
from loguru import logger
|
||||
from typing import Dict, Optional, Union, List, Tuple
|
||||
import sys
|
||||
import os
|
||||
from types import ModuleType
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from src.env import env
|
||||
# from ..plugins.chat.config import global_config
|
||||
|
||||
# 加载 .env 文件
|
||||
env_path = Path(__file__).resolve().parent.parent.parent / ".env"
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
# 保存原生处理器ID
|
||||
default_handler_id = None
|
||||
|
|
@ -32,7 +28,7 @@ _custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义
|
|||
current_file_path = Path(__file__).resolve()
|
||||
LOG_ROOT = "logs"
|
||||
|
||||
SIMPLE_OUTPUT = os.getenv("SIMPLE_OUTPUT", "false").strip().lower()
|
||||
SIMPLE_OUTPUT = env.getenv("SIMPLE_OUTPUT", "false").strip().lower()
|
||||
if SIMPLE_OUTPUT == "true":
|
||||
SIMPLE_OUTPUT = True
|
||||
else:
|
||||
|
|
@ -625,7 +621,7 @@ def get_module_logger(
|
|||
# 控制台处理器
|
||||
console_id = logger.add(
|
||||
sink=sys.stderr,
|
||||
level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
|
||||
level= env.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
|
||||
format=current_config["console_format"],
|
||||
filter=lambda record: record["extra"].get("module") == module_name and "custom_style" not in record["extra"],
|
||||
enqueue=True,
|
||||
|
|
@ -640,7 +636,7 @@ def get_module_logger(
|
|||
|
||||
file_id = logger.add(
|
||||
sink=str(log_file),
|
||||
level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
|
||||
level= env.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
|
||||
format=current_config["file_format"],
|
||||
rotation=current_config["rotation"],
|
||||
retention=current_config["retention"],
|
||||
|
|
@ -686,7 +682,7 @@ def add_custom_style_handler(
|
|||
try:
|
||||
custom_console_id = logger.add(
|
||||
sink=sys.stderr,
|
||||
level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
|
||||
level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_CONSOLE_LEVEL", console_level),
|
||||
format=console_format,
|
||||
filter=lambda record: record["extra"].get("module") == module_name
|
||||
and record["extra"].get("custom_style") == style_name,
|
||||
|
|
@ -710,7 +706,7 @@ def add_custom_style_handler(
|
|||
# try:
|
||||
# custom_file_id = logger.add(
|
||||
# sink=str(log_file),
|
||||
# level=os.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
|
||||
# level= env.getenv(f"{module_name.upper()}_{style_name.upper()}_FILE_LEVEL", file_level),
|
||||
# format=file_format,
|
||||
# rotation=current_config["rotation"],
|
||||
# retention=current_config["retention"],
|
||||
|
|
@ -753,10 +749,10 @@ def remove_module_logger(module_name: str) -> None:
|
|||
|
||||
|
||||
# 添加全局默认处理器(只处理未注册模块的日志--->控制台)
|
||||
# print(os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
|
||||
# print( env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"))
|
||||
DEFAULT_GLOBAL_HANDLER = logger.add(
|
||||
sink=sys.stderr,
|
||||
level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
|
||||
level= env.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
|
|
@ -775,7 +771,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
||||
DEFAULT_FILE_HANDLER = logger.add(
|
||||
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
|
||||
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
|
||||
level= env.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
|
||||
rotation=DEFAULT_CONFIG["rotation"],
|
||||
retention=DEFAULT_CONFIG["retention"],
|
||||
|
|
|
|||
|
|
@ -2,10 +2,16 @@ from fastapi import FastAPI, APIRouter
|
|||
from typing import Optional
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
import os
|
||||
from src.env import env
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||
def __init__(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
app_name: str = "MaiMCore",
|
||||
):
|
||||
self.app = FastAPI(title=app_name)
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8080
|
||||
|
|
@ -46,7 +52,13 @@ class Server:
|
|||
async def run(self):
|
||||
"""启动服务器"""
|
||||
# 禁用 uvicorn 默认日志和访问日志
|
||||
config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = UvicornServer(config=config)
|
||||
try:
|
||||
await self._server.serve()
|
||||
|
|
@ -71,4 +83,4 @@ class Server:
|
|||
return self.app
|
||||
|
||||
|
||||
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
||||
global_server = Server(host=env.getenv("HOST"), port=int(env.getenv("PORT")))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from .env import *
|
||||
|
||||
__all__ = [
|
||||
"env",
|
||||
]
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import os
|
||||
import shutil
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from common.common import BASE_PATH, TEMPLATE_PATH
|
||||
|
||||
|
||||
class EnvConfig:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(EnvConfig, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._load_env()
|
||||
|
||||
def _init_env(self):
|
||||
# 检测.env文件是否存在
|
||||
env_file = BASE_PATH / ".env"
|
||||
if not env_file.exists():
|
||||
print("检测到.env文件不存在")
|
||||
TEMPLATE_ENV_PATH = TEMPLATE_PATH / "template.env"
|
||||
ENV_PATH = BASE_PATH / ".env"
|
||||
shutil.copy(TEMPLATE_ENV_PATH, ENV_PATH)
|
||||
print(f"已从{TEMPLATE_ENV_PATH}复制创建{ENV_PATH},请修改配置后重新启动")
|
||||
|
||||
def _load_env(self):
|
||||
self._init_env()
|
||||
env_file = BASE_PATH / ".env"
|
||||
load_dotenv(env_file)
|
||||
|
||||
# 根据ENVIRONMENT变量加载对应的环境文件
|
||||
env_type = os.getenv("ENVIRONMENT", "prod")
|
||||
if env_type == "dev":
|
||||
env_file = BASE_PATH / ".env.dev"
|
||||
elif env_type == "prod":
|
||||
env_file = BASE_PATH / ".env"
|
||||
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file, override=True)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""获取环境变量"""
|
||||
return os.getenv(key, default)
|
||||
|
||||
def getenv(self, key, default=None):
|
||||
"""获取环境变量"""
|
||||
return self.get(key=key, default=default)
|
||||
|
||||
def get_all(self):
|
||||
"""获取所有环境变量"""
|
||||
return dict(os.environ)
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
env = EnvConfig()
|
||||
"""环境变量管理器"""
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from src.common.logger import get_module_logger
|
||||
from src.env import env
|
||||
|
||||
logger = get_module_logger("offline_llm")
|
||||
|
||||
|
|
@ -14,8 +14,8 @@ class LLMRequestOff:
|
|||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
self.api_key = env.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = env.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
if not self.api_key or not self.base_url:
|
||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Dict, List
|
||||
import json
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
import toml
|
||||
import random
|
||||
|
|
@ -21,15 +20,6 @@ from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # n
|
|||
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
|
||||
from src.individuality.offline_llm import LLMRequestOff # noqa E402
|
||||
|
||||
# 加载环境变量
|
||||
env_path = os.path.join(root_path, ".env")
|
||||
if os.path.exists(env_path):
|
||||
print(f"从 {env_path} 加载环境变量")
|
||||
load_dotenv(env_path)
|
||||
else:
|
||||
print(f"未找到环境变量文件: {env_path}")
|
||||
print("将使用默认配置")
|
||||
|
||||
|
||||
def adapt_scene(scene: str) -> str:
|
||||
personality_core = config["personality"]["personality_core"]
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ from rich.console import Console
|
|||
from Hippocampus import Hippocampus # 海马体和记忆图
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
"""
|
||||
我想 总有那么一个瞬间
|
||||
|
|
@ -36,15 +34,6 @@ from src.common.database import db # noqa E402
|
|||
logger = get_module_logger("mem_alter")
|
||||
console = Console()
|
||||
|
||||
# 加载环境变量
|
||||
if env_path.exists():
|
||||
logger.info(f"从 {env_path} 加载环境变量")
|
||||
load_dotenv(env_path)
|
||||
else:
|
||||
logger.warning(f"未找到环境变量文件: {env_path}")
|
||||
logger.info("将使用默认配置")
|
||||
|
||||
|
||||
# 查询节点信息
|
||||
def query_mem_info(hippocampus: Hippocampus):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from src.common.logger import get_module_logger
|
||||
from src.env import env
|
||||
|
||||
logger = get_module_logger("offline_llm")
|
||||
|
||||
|
|
@ -14,8 +14,8 @@ class LLMRequestOff:
|
|||
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
self.api_key = env.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = env.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
if not self.api_key or not self.base_url:
|
||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Dict, Any, Callable, List, Set, Optional
|
|||
from src.common.logger import get_module_logger
|
||||
from src.plugins.message.message_base import MessageBase
|
||||
from src.common.server import global_server
|
||||
from src.env import env
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import uvicorn
|
||||
|
|
@ -247,4 +248,4 @@ class MessageServer(BaseMessageHandler):
|
|||
raise e
|
||||
|
||||
|
||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||
global_api = MessageServer(host= env.getenv("HOST"), port=int( env.getenv("PORT")), app=global_server.get_app())
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import io
|
|||
import os
|
||||
from ...common.database import db
|
||||
from ...config.config import global_config
|
||||
|
||||
from src.env import env
|
||||
logger = get_module_logger("model_utils")
|
||||
|
||||
|
||||
|
|
@ -86,8 +86,8 @@ class LLMRequest:
|
|||
def __init__(self, model: dict, **kwargs):
|
||||
# 将大写的配置键转换为小写并从config中获取实际值
|
||||
try:
|
||||
self.api_key = os.environ[model["key"]]
|
||||
self.base_url = os.environ[model["base_url"]]
|
||||
self.api_key = env.getenv(model["key"])
|
||||
self.base_url = env.getenv(model["base_url"])
|
||||
except AttributeError as e:
|
||||
logger.error(f"原始 model dict 信息:{model}")
|
||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
|
|
@ -14,19 +13,14 @@ sys.path.append(root_path)
|
|||
|
||||
# 现在可以导入src模块
|
||||
from src.common.database import db # noqa E402
|
||||
|
||||
# 加载根目录下的env.edv文件
|
||||
env_path = os.path.join(root_path, ".env")
|
||||
if not os.path.exists(env_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||
load_dotenv(env_path)
|
||||
from src.env import env
|
||||
|
||||
|
||||
class KnowledgeLibrary:
|
||||
def __init__(self):
|
||||
self.raw_info_dir = "data/raw_info"
|
||||
self._ensure_dirs()
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.api_key = env.getenv("SILICONFLOW_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||
self.console = Console()
|
||||
|
|
|
|||
Loading…
Reference in New Issue