恢复可用性

r-dev
UnCLAS-Prommer 2026-01-16 23:03:45 +08:00
parent 4855cbc265
commit 3a66bfeac1
No known key found for this signature in database
10 changed files with 166 additions and 118 deletions

6
bot.py
View File

@ -44,7 +44,11 @@ logger = get_logger("main")
# 定义重启退出码
RESTART_EXIT_CODE = 42
print("-----------------------------------------")
print("\n\n\n\n\n")
print("警告Dev进入不稳定开发状态任何插件与WebUI均可能无法正常工作")
print("\n\n\n\n\n")
print("-----------------------------------------")
def run_runner_process():
"""

View File

@ -26,11 +26,11 @@ version 0.3.0 - 2026-01-11
- [x] 其他上面两个依赖的函数已经合并到这两个函数中
### ExpressionConfig
- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsGroupGenerator`中
- [x] 迁移了原来在`ExpressionConfig`中的方法到一个单独的临时类`TempMethodsExpression`中
- [x] get_expression_config_for_chat
- [x] 其他上面依赖的函数已经合并到这个函数中
### ModelConfig
- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsModelConfig`中
- [ ] get_model_info
- [ ] get_provider
- [x] 迁移了原来在`ModelConfig`中的方法到一个单独的临时类`TempMethodsLLMUtils`中
- [x] get_model_info
- [x] get_provider

View File

@ -11,6 +11,7 @@ from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import weighted_sample
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
@ -59,7 +60,7 @@ class ExpressionSelector:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")

View File

@ -5,6 +5,7 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.jargon_miner import miner_manager
@ -38,7 +39,7 @@ class MessageRecorder:
"""初始化提取参数"""
# 获取 expression 配置
_, self.enable_expression_learning, self.enable_jargon_learning = (
global_config.expression.get_expression_config_for_chat(self.chat_id)
TempMethodsExpression.get_expression_config_for_chat(self.chat_id)
)
self.min_messages_for_extraction = 30
self.min_extraction_interval = 60

View File

@ -37,6 +37,7 @@ from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.chat.utils.common_utils import TempMethodsExpression
init_lpmm_prompt()
init_replyer_prompt()
@ -358,7 +359,7 @@ class DefaultReplyer:
str: 表达习惯信息字符串
"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
if not use_expression:
return "", []
style_habits = []
@ -1254,66 +1255,3 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
pool.pop(idx)
break
return selected
class TempMethodsGroupGenerator:
"""用于临时存放一些方法的类"""
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
"""
根据聊天流ID获取表达配置
Args:
chat_stream_id: 聊天流ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
"""
if not global_config.expression.learning_list:
return True, True, True
if chat_stream_id:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue # 这是全局的
stream_id = TempMethodsGroupGenerator._get_stream_id(
config_item.platform,
str(config_item.item_id),
(config_item.rule_type == "group"),
)
if stream_id is None:
continue
if stream_id == chat_stream_id:
continue
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return True, True, True
@staticmethod
def _get_stream_id(
platform: str,
id_str: str,
is_group: bool = False,
) -> Optional[str]:
"""
根据平台ID字符串和是否为群聊生成聊天流ID
Args:
platform: 平台名称
id_str: 用户或群组的原始ID字符串
is_group: 是否为群聊
Returns:
str: 生成的聊天流ID哈希值
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
except Exception as e:
logger.error(f"生成聊天流ID失败: {e}")
return None

View File

@ -18,6 +18,7 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.common_utils import TempMethodsExpression
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
@ -260,7 +261,7 @@ class PrivateReplyer:
str: 表达习惯信息字符串
"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
if not use_expression:
return "", []
style_habits = []

View File

@ -0,0 +1,68 @@
from typing import Optional
from src.config.config import global_config
from src.common.logger import get_logger
logger = get_logger("common_utils")
class TempMethodsExpression:
"""用于临时存放一些方法的类"""
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
"""
根据聊天流ID获取表达配置
Args:
chat_stream_id: 聊天流ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
"""
if not global_config.expression.learning_list:
return True, True, True
if chat_stream_id:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue # 这是全局的
stream_id = TempMethodsExpression._get_stream_id(
config_item.platform,
str(config_item.item_id),
(config_item.rule_type == "group"),
)
if stream_id is None:
continue
if stream_id == chat_stream_id:
continue
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return True, True, True
@staticmethod
def _get_stream_id(
platform: str,
id_str: str,
is_group: bool = False,
) -> Optional[str]:
"""
根据平台ID字符串和是否为群聊生成聊天流ID
Args:
platform: 平台名称
id_str: 用户或群组的原始ID字符串
is_group: 是否为群聊
Returns:
str: 生成的聊天流ID哈希值
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
except Exception as e:
logger.error(f"生成聊天流ID失败: {e}")
return None

View File

@ -162,44 +162,6 @@ class ModelConfig(ConfigBase):
return super().model_post_init(context)
def get_model_info_by_name(model_config: ModelConfig, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息
Args:
model_config: ModelConfig实例
model_name: 模型名称
Returns:
ModelInfo: 模型信息
Raises:
ValueError: 未找到指定模型
"""
for model in model_config.models:
if model.name == model_name:
return model
raise ValueError(f"未找到名为 '{model_name}' 的模型")
def get_provider_by_name(model_config: ModelConfig, provider_name: str) -> APIProvider:
"""根据提供商名称获取提供商信息
Args:
model_config: ModelConfig实例
provider_name: 提供商名称
Returns:
APIProvider: API提供商信息
Raises:
ValueError: 未找到指定提供商
"""
for provider in model_config.api_providers:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
class ConfigManager:
"""总配置管理类"""
@ -321,4 +283,4 @@ def write_config_to_file(
config_manager = ConfigManager()
config_manager.initialize()
global_config = config_manager.get_global_config()
model_config = config_manager.get_model_config()
model_config = config_manager.get_model_config()

View File

@ -558,7 +558,7 @@ async def start_dream_scheduler(
start_ts = time.time()
# 检查当前时间是否在允许做梦的时间段内
if not global_config.dream.is_in_dream_time():
if not TempMethodsDream.is_in_dream_time():
logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行")
else:
try:
@ -577,3 +577,34 @@ async def start_dream_scheduler(
# 初始化提示词
init_dream_prompts()
class TempMethodsDream:
@staticmethod
def is_in_dream_time() -> bool:
if not global_config.dream.dream_time_ranges:
return True
now_min = time.localtime()
now_total_min = now_min.tm_hour * 60 + now_min.tm_min
for time_range in global_config.dream.dream_time_ranges:
if parsed := TempMethodsDream._parse_range(time_range):
start_min, end_min = parsed
if TempMethodsDream._in_range(now_total_min, start_min, end_min):
return True
return False
@staticmethod
def _in_range(now_min, start_min, end_min) -> bool:
if start_min <= end_min:
return start_min <= now_min <= end_min
return now_min >= start_min or now_min <= end_min
@staticmethod
def _parse_range(range_str: str) -> Optional[Tuple[int, int]]:
try:
start_str, end_str = [s.strip() for s in range_str.split("-")]
sh, sm = [int(x) for x in start_str.split(":")]
eh, em = [int(x) for x in end_str.split(":")]
return sh * 60 + sm, eh * 60 + em
except Exception:
return None

View File

@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback
from src.common.logger import get_logger
from src.config.config import model_config, get_model_info_by_name, get_provider_by_name
from src.config.config import model_config
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
@ -278,7 +278,7 @@ class LLMRequest:
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
strategy = self.model_for_task.selection_strategy.lower()
if strategy == "random":
# 随机选择策略
selected_model_name = random.choice(list(available_models.keys()))
@ -295,9 +295,9 @@ class LLMRequest:
available_models,
key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
)
model_info = get_model_info_by_name(model_config, selected_model_name)
api_provider = get_provider_by_name(model_config, model_info.api_provider)
model_info = TempMethodsLLMUtils.get_model_info_by_name(selected_model_name)
api_provider = TempMethodsLLMUtils.get_provider_by_name(model_info.api_provider)
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})")
@ -456,7 +456,9 @@ class LLMRequest:
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
raise ModelAttemptFailed(f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试因为重试次数已配置为0或更少。")
raise ModelAttemptFailed(
f"任务 '{self.request_type or '未知任务'}' 的模型 '{model_info.name}' 未被尝试因为重试次数已配置为0或更少。"
)
async def _execute_request(
self,
@ -576,3 +578,43 @@ class LLMRequest:
original_error_msg = str(e.__cause__)
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
return ""
class TempMethodsLLMUtils:
@staticmethod
def get_model_info_by_name(model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息
Args:
model_config: ModelConfig实例
model_name: 模型名称
Returns:
ModelInfo: 模型信息
Raises:
ValueError: 未找到指定模型
"""
for model in model_config.models:
if model.name == model_name:
return model
raise ValueError(f"未找到名为 '{model_name}' 的模型")
@staticmethod
def get_provider_by_name(provider_name: str) -> APIProvider:
"""根据提供商名称获取提供商信息
Args:
model_config: ModelConfig实例
provider_name: 提供商名称
Returns:
APIProvider: API提供商信息
Raises:
ValueError: 未找到指定提供商
"""
for provider in model_config.api_providers:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")