mirror of https://github.com/Mai-with-u/MaiBot.git
恢复可用性
parent
4855cbc265
commit
3a66bfeac1
6
bot.py
6
bot.py
|
|
@ -44,7 +44,11 @@ logger = get_logger("main")
|
|||
|
||||
# 定义重启退出码
|
||||
RESTART_EXIT_CODE = 42
|
||||
|
||||
print("-----------------------------------------")
|
||||
print("\n\n\n\n\n")
|
||||
print("警告:Dev进入不稳定开发状态,任何插件与WebUI均可能无法正常工作!")
|
||||
print("\n\n\n\n\n")
|
||||
print("-----------------------------------------")
|
||||
|
||||
def run_runner_process():
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -26,11 +26,11 @@ version 0.3.0 - 2026-01-11
|
|||
- [x] 其他上面两个依赖的函数已经合并到这两个函数中
|
||||
|
||||
### ExpressionConfig
|
||||
- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsGroupGenerator`中
|
||||
- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsExpression`中
|
||||
- [x] get_expression_config_for_chat
|
||||
- [x] 其他上面依赖的函数已经合并到这个函数中
|
||||
|
||||
### ModelConfig
|
||||
- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsModelConfig`中
|
||||
- [ ] get_model_info
|
||||
- [ ] get_provider
|
||||
- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsLLMUtils`中
|
||||
- [x] get_model_info
|
||||
- [x] get_provider
|
||||
|
|
@ -11,6 +11,7 @@ from src.common.database.database_model import Expression
|
|||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.learner_utils import weighted_sample
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
|
@ -59,7 +60,7 @@ class ExpressionSelector:
|
|||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from src.common.logger import get_logger
|
|||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ class MessageRecorder:
|
|||
"""初始化提取参数"""
|
||||
# 获取 expression 配置
|
||||
_, self.enable_expression_learning, self.enable_jargon_learning = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
TempMethodsExpression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
self.min_messages_for_extraction = 30
|
||||
self.min_extraction_interval = 60
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
|
|||
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
init_lpmm_prompt()
|
||||
init_replyer_prompt()
|
||||
|
|
@ -358,7 +359,7 @@ class DefaultReplyer:
|
|||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
|
|
@ -1254,66 +1255,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
|||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
class TempMethodsGroupGenerator:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
stream_id = TempMethodsGroupGenerator._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
(config_item.rule_type == "group"),
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == chat_stream_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
|
||||
return True, True, True
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
platform: str,
|
||||
id_str: str,
|
||||
is_group: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
|||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
|
|
@ -260,7 +261,7 @@ class PrivateReplyer:
|
|||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
stream_id = TempMethodsExpression._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
(config_item.rule_type == "group"),
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == chat_stream_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
|
||||
return True, True, True
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
platform: str,
|
||||
id_str: str,
|
||||
is_group: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
return None
|
||||
|
|
@ -162,44 +162,6 @@ class ModelConfig(ConfigBase):
|
|||
return super().model_post_init(context)
|
||||
|
||||
|
||||
def get_model_info_by_name(model_config: ModelConfig, model_name: str) -> ModelInfo:
|
||||
"""根据模型名称获取模型信息
|
||||
|
||||
Args:
|
||||
model_config: ModelConfig实例
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
ModelInfo: 模型信息
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到指定模型
|
||||
"""
|
||||
for model in model_config.models:
|
||||
if model.name == model_name:
|
||||
return model
|
||||
raise ValueError(f"未找到名为 '{model_name}' 的模型")
|
||||
|
||||
|
||||
def get_provider_by_name(model_config: ModelConfig, provider_name: str) -> APIProvider:
|
||||
"""根据提供商名称获取提供商信息
|
||||
|
||||
Args:
|
||||
model_config: ModelConfig实例
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
APIProvider: API提供商信息
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到指定提供商
|
||||
"""
|
||||
for provider in model_config.api_providers:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""总配置管理类"""
|
||||
|
||||
|
|
@ -321,4 +283,4 @@ def write_config_to_file(
|
|||
config_manager = ConfigManager()
|
||||
config_manager.initialize()
|
||||
global_config = config_manager.get_global_config()
|
||||
model_config = config_manager.get_model_config()
|
||||
model_config = config_manager.get_model_config()
|
||||
|
|
|
|||
|
|
@ -558,7 +558,7 @@ async def start_dream_scheduler(
|
|||
|
||||
start_ts = time.time()
|
||||
# 检查当前时间是否在允许做梦的时间段内
|
||||
if not global_config.dream.is_in_dream_time():
|
||||
if not TempMethodsDream.is_in_dream_time():
|
||||
logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行")
|
||||
else:
|
||||
try:
|
||||
|
|
@ -577,3 +577,34 @@ async def start_dream_scheduler(
|
|||
|
||||
# 初始化提示词
|
||||
init_dream_prompts()
|
||||
|
||||
|
||||
class TempMethodsDream:
|
||||
@staticmethod
|
||||
def is_in_dream_time() -> bool:
|
||||
if not global_config.dream.dream_time_ranges:
|
||||
return True
|
||||
now_min = time.localtime()
|
||||
now_total_min = now_min.tm_hour * 60 + now_min.tm_min
|
||||
for time_range in global_config.dream.dream_time_ranges:
|
||||
if parsed := TempMethodsDream._parse_range(time_range):
|
||||
start_min, end_min = parsed
|
||||
if TempMethodsDream._in_range(now_total_min, start_min, end_min):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _in_range(now_min, start_min, end_min) -> bool:
|
||||
if start_min <= end_min:
|
||||
return start_min <= now_min <= end_min
|
||||
return now_min >= start_min or now_min <= end_min
|
||||
|
||||
@staticmethod
|
||||
def _parse_range(range_str: str) -> Optional[Tuple[int, int]]:
|
||||
try:
|
||||
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||
sh, sm = [int(x) for x in start_str.split(":")]
|
||||
eh, em = [int(x) for x in end_str.split(":")]
|
||||
return sh * 60 + sm, eh * 60 + em
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
|
|||
import traceback
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config, get_model_info_by_name, get_provider_by_name
|
||||
from src.config.config import model_config
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
||||
from .payload_content.message import MessageBuilder, Message
|
||||
from .payload_content.resp_format import RespFormat
|
||||
|
|
@ -278,7 +278,7 @@ class LLMRequest:
|
|||
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
|
||||
|
||||
strategy = self.model_for_task.selection_strategy.lower()
|
||||
|
||||
|
||||
if strategy == "random":
|
||||
# 随机选择策略
|
||||
selected_model_name = random.choice(list(available_models.keys()))
|
||||
|
|
@ -295,9 +295,9 @@ class LLMRequest:
|
|||
available_models,
|
||||
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
|
||||
)
|
||||
|
||||
model_info = get_model_info_by_name(model_config, selected_model_name)
|
||||
api_provider = get_provider_by_name(model_config, model_info.api_provider)
|
||||
|
||||
model_info = TempMethodsLLMUtils.get_model_info_by_name(selected_model_name)
|
||||
api_provider = TempMethodsLLMUtils.get_provider_by_name(model_info.api_provider)
|
||||
force_new_client = self.request_type == "embedding"
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})")
|
||||
|
|
@ -456,7 +456,9 @@ class LLMRequest:
|
|||
)
|
||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
|
||||
|
||||
raise ModelAttemptFailed(f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。")
|
||||
raise ModelAttemptFailed(
|
||||
f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试,因为重试次数已配置为0或更少。"
|
||||
)
|
||||
|
||||
async def _execute_request(
|
||||
self,
|
||||
|
|
@ -576,3 +578,43 @@ class LLMRequest:
|
|||
original_error_msg = str(e.__cause__)
|
||||
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||
return ""
|
||||
|
||||
|
||||
class TempMethodsLLMUtils:
|
||||
@staticmethod
|
||||
def get_model_info_by_name(model_name: str) -> ModelInfo:
|
||||
"""根据模型名称获取模型信息
|
||||
|
||||
Args:
|
||||
model_config: ModelConfig实例
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
ModelInfo: 模型信息
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到指定模型
|
||||
"""
|
||||
for model in model_config.models:
|
||||
if model.name == model_name:
|
||||
return model
|
||||
raise ValueError(f"未找到名为 '{model_name}' 的模型")
|
||||
|
||||
@staticmethod
|
||||
def get_provider_by_name(provider_name: str) -> APIProvider:
|
||||
"""根据提供商名称获取提供商信息
|
||||
|
||||
Args:
|
||||
model_config: ModelConfig实例
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
APIProvider: API提供商信息
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到指定提供商
|
||||
"""
|
||||
for provider in model_config.api_providers:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||
|
|
|
|||
Loading…
Reference in New Issue