Gradually adapt to the new config

r-dev
UnCLAS-Prommer 2026-01-15 23:51:19 +08:00
parent c8b4366501
commit 77725ba9d8
12 changed files with 111 additions and 30 deletions

View File

@@ -11,4 +11,10 @@ version 0.3.0 - 2026-01-11
- [x] All atomic fields in a config class may only be of these types: `str`, `int`, `float`, `bool`, `list`, `dict`, `set`
- [x] `Union` types are forbidden
- [x] `tuple` types are forbidden; use a nested `dataclass` instead
- [x] Complex types are implemented as nested config classes (see the sketch after this list)
- [x] Config classes may define no method other than `model_post_init`
- [x] Renamed fields that clashed with standard functions
- [x] `id` -> `item_id`
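Taken together, these rules keep every config class flat and serializable. A minimal sketch of a class that follows them, assuming `ConfigBase` is a pydantic v2 model (the `WindowItem`/`ExampleConfig` names here are hypothetical):

```python
from typing import Optional
from pydantic import BaseModel, Field


class ConfigBase(BaseModel):
    """Stand-in for the project's ConfigBase (assumed to be a pydantic v2 model)."""


class WindowItem(ConfigBase):
    # A nested config class replaces what would otherwise be a tuple.
    start: str = "00:00"
    end: str = "23:59"


class ExampleConfig(ConfigBase):
    item_id: str = ""    # renamed from `id`, which shadows the builtin
    weight: float = 1.0  # atomic fields restricted to str/int/float/bool/list/dict/set
    windows: list[WindowItem] = Field(default_factory=list)

    def model_post_init(self, context: Optional[dict] = None) -> None:
        # The only method a config class is allowed to define.
        if self.weight < 0:
            raise ValueError("weight must be non-negative")
        return super().model_post_init(context)
```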
### BotConfig design
- [ ] Trimmed the config items: only Nickname and Alias Name remain; mention detection is expected to move to the Adapter side

View File

@@ -757,7 +757,7 @@ class MCPToolProxy(BaseTool):
) -> Optional[str]:
"""调用 LLM 进行后处理"""
from src.config.config import model_config
from src.config.api_ada_configs import TaskConfig
from src.config.model_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest
model_name = settings.get("post_process_model", "")

View File

@@ -216,14 +216,14 @@ class HeartFChatting:
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
mentioned_message = message
# logger.info(f"{self.log_prefix} current talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
# logger.info(f"{self.log_prefix} current talk_value: {TempMethods.get_talk_value(self.stream_id)}")
# * for frequency control
if mentioned_message:
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
elif (
random.random()
< global_config.chat.get_talk_value(self.stream_id)
< TempMethodsHFC.get_talk_value(self.stream_id)
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
):
await self._observe(recent_messages_list=recent_messages_list)
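The gate above is a plain Bernoulli draw: a message is observed with probability `talk_value * frequency_adjust`, and any product at or above 1.0 always fires, since `random.random()` never reaches 1. A standalone sketch with hypothetical numbers:

```python
import random


def should_reply(talk_value: float, frequency_adjust: float) -> bool:
    # random.random() is uniform on [0, 1), so the product acts as a
    # reply probability; products >= 1.0 pass the gate unconditionally.
    return random.random() < talk_value * frequency_adjust


# Hypothetical numbers: base talk_value 0.8, halved by frequency control.
hits = sum(should_reply(0.8, 0.5) for _ in range(10_000))
print(hits / 10_000)  # ≈ 0.4
```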
@@ -325,7 +325,7 @@ class HeartFChatting:
cycle_timers, thinking_id = self.start_cycle()
logger.info(
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {TempMethodsHFC.get_talk_value(self.stream_id)})"
)
# 第一步:动作检查
@@ -547,7 +547,9 @@ class HeartFChatting:
)
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息使用引用回复或者上次回复时间超过90秒")
logger.info(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息使用引用回复或者上次回复时间超过90秒"
)
reply_text = ""
first_replied = False
@@ -665,7 +667,7 @@ class HeartFChatting:
cleaned_uw.append(s)
if cleaned_uw:
unknown_words = cleaned_uw
# Extract the quote_message parameter from the Planner's action_data
qm = action_planner_info.action_data.get("quote")
if qm is not None:
@@ -676,7 +678,7 @@ class HeartFChatting:
quote_message = qm.lower() in ("true", "1", "yes")
elif isinstance(qm, (int, float)):
quote_message = bool(qm)
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
success, llm_response = await generator_api.generate_reply(
@@ -748,3 +750,64 @@ class HeartFChatting:
"loop_info": None,
"error": str(e),
}
class TempMethodsHFC:
    @staticmethod
    def get_talk_value(chat_id: Optional[str]) -> float:
        result = global_config.chat.talk_value or 0.0000001
        if not global_config.chat.enable_talk_value_rules or not global_config.chat.talk_value_rules:
            return result
        import time

        local_time = time.localtime()
        now_min = local_time.tm_hour * 60 + local_time.tm_min
        # Specific rules first
        if chat_id:
            for rule in global_config.chat.talk_value_rules:
                if not rule.platform and not rule.item_id:
                    continue  # both fields empty means a global rule; skip here
                is_group = rule.rule_type == "group"
                from src.chat.message_receive.chat_stream import get_chat_manager

                stream_id = get_chat_manager().get_stream_id(rule.platform, str(rule.item_id), is_group)
                if stream_id != chat_id:
                    continue
                parsed_range = TempMethodsHFC._parse_range(rule.time)
                if not parsed_range:
                    continue
                start_min, end_min = parsed_range
                in_range: bool = False
                if start_min <= end_min:
                    in_range = start_min <= now_min <= end_min
                else:
                    in_range = now_min >= start_min or now_min <= end_min
                if in_range:
                    return rule.value or 0.0
        # Then global rules
        for rule in global_config.chat.talk_value_rules:
            if rule.platform or rule.item_id:
                continue  # either field set means a specific rule; skip here
            parsed_range = TempMethodsHFC._parse_range(rule.time)
            if not parsed_range:
                continue
            start_min, end_min = parsed_range
            in_range = False
            if start_min <= end_min:
                in_range = start_min <= now_min <= end_min
            else:
                in_range = now_min >= start_min or now_min <= end_min
            if in_range:
                return rule.value or 0.0000001
        return result

    @staticmethod
    def _parse_range(range_str: str) -> Optional[tuple[int, int]]:
        """Parse "HH:MM-HH:MM" into (start_min, end_min)."""
        try:
            start_str, end_str = [s.strip() for s in range_str.split("-")]
            sh, sm = [int(x) for x in start_str.split(":")]
            eh, em = [int(x) for x in end_str.split(":")]
            return sh * 60 + sm, eh * 60 + em
        except Exception:
            return None
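`_parse_range` converts both endpoints to minutes since midnight, and `get_talk_value` treats `start > end` as a window that wraps past midnight. A standalone check of that logic (same parsing as above, without the `global_config` dependency):

```python
from typing import Optional


def parse_range(range_str: str) -> Optional[tuple[int, int]]:
    # Mirrors TempMethodsHFC._parse_range.
    try:
        start_str, end_str = [s.strip() for s in range_str.split("-")]
        sh, sm = [int(x) for x in start_str.split(":")]
        eh, em = [int(x) for x in end_str.split(":")]
        return sh * 60 + sm, eh * 60 + em
    except Exception:
        return None


def in_window(now_min: int, start_min: int, end_min: int) -> bool:
    # start <= end is an ordinary window; start > end wraps past midnight.
    if start_min <= end_min:
        return start_min <= now_min <= end_min
    return now_min >= start_min or now_min <= end_min


assert parse_range("23:00-01:30") == (1380, 90)
assert in_window(23 * 60 + 59, 1380, 90)  # 23:59 falls inside 23:00-01:30
assert not in_window(12 * 60, 1380, 90)   # noon falls outside it
assert parse_range("not a range") is None
```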

View File

@@ -282,4 +282,5 @@ def write_config_to_file(
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
config_manager = ConfigManager()
config_manager.initialize()
# global_config = config_manager.get_global_config()
global_config = config_manager.get_global_config()
model_config = config_manager.get_model_config()
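With `initialize()` running at import time, both objects behave as module-level singletons; downstream code just imports the names. A minimal usage sketch, assuming both are exported from `src.config.config` as the imports elsewhere in this commit suggest:

```python
from src.config.config import global_config, model_config

# Both names are built exactly once, when src.config.config is first imported.
base_value = global_config.chat.talk_value  # consumed by TempMethodsHFC above
```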

View File

@@ -18,6 +18,17 @@ class ExampleConfig(ConfigBase):
class BotConfig(ConfigBase):
"""机器人配置类"""
platform: str = ""
"""平台"""
qq_account: int = 0
"""QQ账号"""
platforms: list[str] = Field(default_factory=lambda: [])
"""其他平台"""
nickname: str = "麦麦"
"""机器人昵称"""
alias_names: list[str] = Field(default_factory=lambda: [])
"""别名列表"""
@@ -70,10 +81,10 @@ class RelationshipConfig(ConfigBase):
class TalkRulesItem(ConfigBase):
platform: str = ""
"""平台,留空表示全局"""
"""平台,与ID一起留空表示全局"""
id: str = ""
"""用户ID"""
item_id: str = ""
"""用户ID,与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = "group"
"""聊天流类型group群聊或private私聊"""
@@ -119,8 +130,8 @@ class ChatConfig(ConfigBase):
talk_value_rules: list[TalkRulesItem] = Field(
default_factory=lambda: [
TalkRulesItem(platform="", id="", rule_type="group", time="00:00-08:59", value=0.8),
TalkRulesItem(platform="", id="", rule_type="group", time="09:00-18:59", value=1.0),
TalkRulesItem(platform="", item_id="", rule_type="group", time="00:00-08:59", value=0.8),
TalkRulesItem(platform="", item_id="", rule_type="group", time="09:00-18:59", value=1.0),
]
)
"""
@@ -148,10 +159,10 @@ class MessageReceiveConfig(ConfigBase):
class TargetItem(ConfigBase):
platform: str = ""
"""平台,留空表示全局"""
"""平台,与ID一起留空表示全局"""
id: str = ""
"""用户ID"""
item_id: str = ""
"""用户ID,与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = "group"
"""聊天流类型group群聊或private私聊"""
@@ -218,10 +229,10 @@ class MemoryConfig(ConfigBase):
class LearningItem(ConfigBase):
platform: str = ""
"""平台,留空表示全局"""
"""平台,与ID一起留空表示全局"""
id: str = ""
"""用户ID"""
item_id: str = ""
"""用户ID,与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = "group"
"""聊天流类型group群聊或private私聊"""
@@ -250,7 +261,7 @@ class ExpressionConfig(ConfigBase):
default_factory=lambda: [
LearningItem(
platform="",
id="",
item_id="",
rule_type="group",
use_expression=True,
enable_learning=True,
@@ -468,7 +479,7 @@ class ExtraPromptItem(ConfigBase):
platform: str = ""
"""平台,留空无效"""
id: str = ""
item_id: str = ""
"""用户ID留空无效"""
rule_type: Literal["group", "private"] = "group"
@@ -478,7 +489,7 @@ class ExtraPromptItem(ConfigBase):
"""额外的prompt内容"""
def model_post_init(self, context: Optional[dict] = None) -> None:
if not self.platform or not self.id or not self.prompt:
if not self.platform or not self.item_id or not self.prompt:
    raise ValueError("platform, item_id and prompt in ExtraPromptItem must not be empty")
return super().model_post_init(context)
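After the rename, the validator fails fast on the new field name. A quick sketch (hypothetical `qq`/`123456` values; assuming the `ValueError` raised in `model_post_init` propagates unwrapped):

```python
from src.config.official_configs import ExtraPromptItem

ok = ExtraPromptItem(platform="qq", item_id="123456", prompt="Be terse.")

try:
    ExtraPromptItem(platform="qq", item_id="", prompt="Be terse.")
except ValueError as e:
    print(e)  # platform, item_id and prompt in ExtraPromptItem must not be empty
```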

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.config.model_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall

View File

@@ -30,7 +30,7 @@ from google.genai.errors import (
FunctionInvocationError,
)
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.config.model_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry

View File

@@ -22,7 +22,7 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.config.model_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db  # ensure db is imported for create_tables
from src.common.database.database_model import LLMUsage
from src.config.api_ada_configs import ModelInfo
from src.config.model_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord

View File

@@ -10,7 +10,7 @@ import traceback
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType

View File

@@ -14,7 +14,7 @@ from src.llm_models.payload_content.message import Message
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.config.api_ada_configs import TaskConfig
from src.config.model_configs import TaskConfig
logger = get_logger("llm_api")

View File

@@ -32,7 +32,7 @@ from src.config.official_configs import (
DebugConfig,
VoiceConfig,
)
from src.config.api_ada_configs import (
from src.config.model_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,