mirror of https://github.com/Mai-with-u/MaiBot.git
逐步适配新的config
parent
c8b4366501
commit
77725ba9d8
|
|
@ -12,3 +12,9 @@ version 0.3.0 - 2026-01-11
|
||||||
- [x] 禁止使用 `Union` 类型
|
- [x] 禁止使用 `Union` 类型
|
||||||
- [x] 禁止使用`tuple`类型,使用嵌套`dataclass`替代
|
- [x] 禁止使用`tuple`类型,使用嵌套`dataclass`替代
|
||||||
- [x] 复杂类型使用嵌套配置类实现
|
- [x] 复杂类型使用嵌套配置类实现
|
||||||
|
- [x] 配置类中禁止使用除了`model_post_init`的方法
|
||||||
|
- [x] 取代了部分与标准函数混淆的命名
|
||||||
|
- [x] `id` -> `item_id`
|
||||||
|
|
||||||
|
### BotConfig 设计
|
||||||
|
- [ ] 精简了配置项,现在只有Nickname和Alias Name了(预期将判断提及移到Adapter端)
|
||||||
|
|
|
||||||
|
|
@ -757,7 +757,7 @@ class MCPToolProxy(BaseTool):
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""调用 LLM 进行后处理"""
|
"""调用 LLM 进行后处理"""
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.config.api_ada_configs import TaskConfig
|
from src.config.model_configs import TaskConfig
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
model_name = settings.get("post_process_model", "")
|
model_name = settings.get("post_process_model", "")
|
||||||
|
|
|
||||||
|
|
@ -216,14 +216,14 @@ class HeartFChatting:
|
||||||
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply:
|
||||||
mentioned_message = message
|
mentioned_message = message
|
||||||
|
|
||||||
# logger.info(f"{self.log_prefix} 当前talk_value: {global_config.chat.get_talk_value(self.stream_id)}")
|
# logger.info(f"{self.log_prefix} 当前talk_value: {TempMethods.get_talk_value(self.stream_id)}")
|
||||||
|
|
||||||
# *控制频率用
|
# *控制频率用
|
||||||
if mentioned_message:
|
if mentioned_message:
|
||||||
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message)
|
||||||
elif (
|
elif (
|
||||||
random.random()
|
random.random()
|
||||||
< global_config.chat.get_talk_value(self.stream_id)
|
< TempMethodsHFC.get_talk_value(self.stream_id)
|
||||||
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
* frequency_control_manager.get_or_create_frequency_control(self.stream_id).get_talk_frequency_adjust()
|
||||||
):
|
):
|
||||||
await self._observe(recent_messages_list=recent_messages_list)
|
await self._observe(recent_messages_list=recent_messages_list)
|
||||||
|
|
@ -325,7 +325,7 @@ class HeartFChatting:
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {TempMethodsHFC.get_talk_value(self.stream_id)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 第一步:动作检查
|
# 第一步:动作检查
|
||||||
|
|
@ -547,7 +547,9 @@ class HeartFChatting:
|
||||||
)
|
)
|
||||||
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
|
||||||
if need_reply:
|
if need_reply:
|
||||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒"
|
||||||
|
)
|
||||||
|
|
||||||
reply_text = ""
|
reply_text = ""
|
||||||
first_replied = False
|
first_replied = False
|
||||||
|
|
@ -748,3 +750,64 @@ class HeartFChatting:
|
||||||
"loop_info": None,
|
"loop_info": None,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TempMethodsHFC:
|
||||||
|
@staticmethod
|
||||||
|
def get_talk_value(chat_id: Optional[str]) -> float:
|
||||||
|
result = global_config.chat.talk_value or 0.0000001
|
||||||
|
if not global_config.chat.enable_talk_value_rules or not global_config.chat.talk_value_rules:
|
||||||
|
return result
|
||||||
|
import time
|
||||||
|
|
||||||
|
local_time = time.localtime()
|
||||||
|
now_min = local_time.tm_hour * 60 + local_time.tm_min
|
||||||
|
# 先处理特定规则
|
||||||
|
if chat_id:
|
||||||
|
for rule in global_config.chat.talk_value_rules:
|
||||||
|
if not rule.platform and not rule.item_id:
|
||||||
|
continue # 一起留空表示全局,跳过
|
||||||
|
is_group = rule.rule_type == "group"
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
|
stream_id = get_chat_manager().get_stream_id(rule.platform, str(rule.item_id), is_group)
|
||||||
|
if stream_id != chat_id:
|
||||||
|
continue
|
||||||
|
parsed_range = TempMethodsHFC._parse_range(rule.time)
|
||||||
|
if not parsed_range:
|
||||||
|
continue
|
||||||
|
start_min, end_min = parsed_range
|
||||||
|
in_range: bool = False
|
||||||
|
if start_min <= end_min:
|
||||||
|
in_range = start_min <= now_min <= end_min
|
||||||
|
else:
|
||||||
|
in_range = now_min >= start_min or now_min <= end_min
|
||||||
|
if in_range:
|
||||||
|
return rule.value or 0.0
|
||||||
|
# 再处理全局规则
|
||||||
|
for rule in global_config.chat.talk_value_rules:
|
||||||
|
if rule.platform or rule.item_id:
|
||||||
|
continue # 有指定表示特定,跳过
|
||||||
|
parsed_range = TempMethodsHFC._parse_range(rule.time)
|
||||||
|
if not parsed_range:
|
||||||
|
continue
|
||||||
|
start_min, end_min = parsed_range
|
||||||
|
in_range: bool = False
|
||||||
|
if start_min <= end_min:
|
||||||
|
in_range = start_min <= now_min <= end_min
|
||||||
|
else:
|
||||||
|
in_range = now_min >= start_min or now_min <= end_min
|
||||||
|
if in_range:
|
||||||
|
return rule.value or 0.0000001
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_range(range_str: str) -> Optional[tuple[int, int]]:
|
||||||
|
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
|
||||||
|
try:
|
||||||
|
start_str, end_str = [s.strip() for s in range_str.split("-")]
|
||||||
|
sh, sm = [int(x) for x in start_str.split(":")]
|
||||||
|
eh, em = [int(x) for x in end_str.split(":")]
|
||||||
|
return sh * 60 + sm, eh * 60 + em
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -282,4 +282,5 @@ def write_config_to_file(
|
||||||
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
|
# generate_new_config_file(Config, BOT_CONFIG_PATH, CONFIG_VERSION)
|
||||||
config_manager = ConfigManager()
|
config_manager = ConfigManager()
|
||||||
config_manager.initialize()
|
config_manager.initialize()
|
||||||
# global_config = config_manager.get_global_config()
|
global_config = config_manager.get_global_config()
|
||||||
|
model_config = config_manager.get_model_config()
|
||||||
|
|
@ -18,6 +18,17 @@ class ExampleConfig(ConfigBase):
|
||||||
|
|
||||||
class BotConfig(ConfigBase):
|
class BotConfig(ConfigBase):
|
||||||
"""机器人配置类"""
|
"""机器人配置类"""
|
||||||
|
platform: str = ""
|
||||||
|
"""平台"""
|
||||||
|
|
||||||
|
qq_account: int = 0
|
||||||
|
"""QQ账号"""
|
||||||
|
|
||||||
|
platforms: list[str] = Field(default_factory=lambda: [])
|
||||||
|
"""其他平台"""
|
||||||
|
|
||||||
|
nickname: str = "麦麦"
|
||||||
|
"""机器人昵称"""
|
||||||
|
|
||||||
alias_names: list[str] = Field(default_factory=lambda: [])
|
alias_names: list[str] = Field(default_factory=lambda: [])
|
||||||
"""别名列表"""
|
"""别名列表"""
|
||||||
|
|
@ -70,10 +81,10 @@ class RelationshipConfig(ConfigBase):
|
||||||
|
|
||||||
class TalkRulesItem(ConfigBase):
|
class TalkRulesItem(ConfigBase):
|
||||||
platform: str = ""
|
platform: str = ""
|
||||||
"""平台,留空表示全局"""
|
"""平台,与ID一起留空表示全局"""
|
||||||
|
|
||||||
id: str = ""
|
item_id: str = ""
|
||||||
"""用户ID"""
|
"""用户ID,与平台一起留空表示全局"""
|
||||||
|
|
||||||
rule_type: Literal["group", "private"] = "group"
|
rule_type: Literal["group", "private"] = "group"
|
||||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||||
|
|
@ -119,8 +130,8 @@ class ChatConfig(ConfigBase):
|
||||||
|
|
||||||
talk_value_rules: list[TalkRulesItem] = Field(
|
talk_value_rules: list[TalkRulesItem] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
TalkRulesItem(platform="", id="", rule_type="group", time="00:00-08:59", value=0.8),
|
TalkRulesItem(platform="", item_id="", rule_type="group", time="00:00-08:59", value=0.8),
|
||||||
TalkRulesItem(platform="", id="", rule_type="group", time="09:00-18:59", value=1.0),
|
TalkRulesItem(platform="", item_id="", rule_type="group", time="09:00-18:59", value=1.0),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
@ -148,10 +159,10 @@ class MessageReceiveConfig(ConfigBase):
|
||||||
|
|
||||||
class TargetItem(ConfigBase):
|
class TargetItem(ConfigBase):
|
||||||
platform: str = ""
|
platform: str = ""
|
||||||
"""平台,留空表示全局"""
|
"""平台,与ID一起留空表示全局"""
|
||||||
|
|
||||||
id: str = ""
|
item_id: str = ""
|
||||||
"""用户ID"""
|
"""用户ID,与平台一起留空表示全局"""
|
||||||
|
|
||||||
rule_type: Literal["group", "private"] = "group"
|
rule_type: Literal["group", "private"] = "group"
|
||||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||||
|
|
@ -218,10 +229,10 @@ class MemoryConfig(ConfigBase):
|
||||||
|
|
||||||
class LearningItem(ConfigBase):
|
class LearningItem(ConfigBase):
|
||||||
platform: str = ""
|
platform: str = ""
|
||||||
"""平台,留空表示全局"""
|
"""平台,与ID一起留空表示全局"""
|
||||||
|
|
||||||
id: str = ""
|
item_id: str = ""
|
||||||
"""用户ID"""
|
"""用户ID,与平台一起留空表示全局"""
|
||||||
|
|
||||||
rule_type: Literal["group", "private"] = "group"
|
rule_type: Literal["group", "private"] = "group"
|
||||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||||
|
|
@ -250,7 +261,7 @@ class ExpressionConfig(ConfigBase):
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
LearningItem(
|
LearningItem(
|
||||||
platform="",
|
platform="",
|
||||||
id="",
|
item_id="",
|
||||||
rule_type="group",
|
rule_type="group",
|
||||||
use_expression=True,
|
use_expression=True,
|
||||||
enable_learning=True,
|
enable_learning=True,
|
||||||
|
|
@ -468,7 +479,7 @@ class ExtraPromptItem(ConfigBase):
|
||||||
platform: str = ""
|
platform: str = ""
|
||||||
"""平台,留空无效"""
|
"""平台,留空无效"""
|
||||||
|
|
||||||
id: str = ""
|
item_id: str = ""
|
||||||
"""用户ID,留空无效"""
|
"""用户ID,留空无效"""
|
||||||
|
|
||||||
rule_type: Literal["group", "private"] = "group"
|
rule_type: Literal["group", "private"] = "group"
|
||||||
|
|
@ -478,7 +489,7 @@ class ExtraPromptItem(ConfigBase):
|
||||||
"""额外的prompt内容"""
|
"""额外的prompt内容"""
|
||||||
|
|
||||||
def model_post_init(self, context: Optional[dict] = None) -> None:
|
def model_post_init(self, context: Optional[dict] = None) -> None:
|
||||||
if not self.platform or not self.id or not self.prompt:
|
if not self.platform or not self.item_id or not self.prompt:
|
||||||
raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空")
|
raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空")
|
||||||
return super().model_post_init(context)
|
return super().model_post_init(context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Any, Optional
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
from src.config.model_configs import ModelInfo, APIProvider
|
||||||
from ..payload_content.message import Message
|
from ..payload_content.message import Message
|
||||||
from ..payload_content.resp_format import RespFormat
|
from ..payload_content.resp_format import RespFormat
|
||||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from google.genai.errors import (
|
||||||
FunctionInvocationError,
|
FunctionInvocationError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
from src.config.model_configs import ModelInfo, APIProvider
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from openai.types.chat import (
|
||||||
)
|
)
|
||||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||||
|
|
||||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
from src.config.model_configs import ModelInfo, APIProvider
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||||
from ..exceptions import (
|
from ..exceptions import (
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from datetime import datetime
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
from src.common.database.database_model import LLMUsage
|
from src.common.database.database_model import LLMUsage
|
||||||
from src.config.api_ada_configs import ModelInfo
|
from src.config.model_configs import ModelInfo
|
||||||
from .payload_content.message import Message, MessageBuilder
|
from .payload_content.message import Message, MessageBuilder
|
||||||
from .model_client.base_client import UsageRecord
|
from .model_client.base_client import UsageRecord
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import traceback
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
|
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
||||||
from .payload_content.message import MessageBuilder, Message
|
from .payload_content.message import MessageBuilder, Message
|
||||||
from .payload_content.resp_format import RespFormat
|
from .payload_content.resp_format import RespFormat
|
||||||
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from src.llm_models.payload_content.message import Message
|
||||||
from src.llm_models.model_client.base_client import BaseClient
|
from src.llm_models.model_client.base_client import BaseClient
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.config.api_ada_configs import TaskConfig
|
from src.config.model_configs import TaskConfig
|
||||||
|
|
||||||
logger = get_logger("llm_api")
|
logger = get_logger("llm_api")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ from src.config.official_configs import (
|
||||||
DebugConfig,
|
DebugConfig,
|
||||||
VoiceConfig,
|
VoiceConfig,
|
||||||
)
|
)
|
||||||
from src.config.api_ada_configs import (
|
from src.config.model_configs import (
|
||||||
ModelTaskConfig,
|
ModelTaskConfig,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
APIProvider,
|
APIProvider,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue