增加黑白名单机制

pull/31/head
UnCLAS-Prommer 2025-05-17 16:11:12 +08:00
parent 9f381ab1a9
commit 0c94458ec1
6 changed files with 60 additions and 35 deletions

View File

@ -17,7 +17,7 @@ async def message_recv(server_connection: Server.ServerConnection):
logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message) logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message)
decoded_raw_message: dict = json.loads(raw_message) decoded_raw_message: dict = json.loads(raw_message)
post_type = decoded_raw_message.get("post_type") post_type = decoded_raw_message.get("post_type")
if post_type in ["meta_event","message","notice"]: if post_type in ["meta_event", "message", "notice"]:
await message_queue.put(decoded_raw_message) await message_queue.put(decoded_raw_message)
elif post_type is None: elif post_type is None:
await put_response(decoded_raw_message) await put_response(decoded_raw_message)

View File

@ -23,8 +23,16 @@ class Config:
self.config_path = os.path.join(self.root_path, "config.toml") self.config_path = os.path.join(self.root_path, "config.toml")
def load_config(self): # sourcery skip: extract-method, move-assign def load_config(self): # sourcery skip: extract-method, move-assign
include_configs = ["Nickname", "Napcat_Server", "MaiBot_Server", "Debug", "Voice"] include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"]
if os.path.exists(self.config_path): if not os.path.exists(self.config_path):
logger.error("配置文件不存在!")
logger.info("正在创建配置文件...")
shutil.copy(
os.path.join(self.root_path, "template", "template_config.toml"),
os.path.join(self.root_path, "config.toml"),
)
logger.info("配置文件创建成功,请修改配置文件后重启程序。")
sys.exit(1)
with open(self.config_path, "rb") as f: with open(self.config_path, "rb") as f:
try: try:
raw_config = tomli.load(f) raw_config = tomli.load(f)
@ -35,28 +43,29 @@ class Config:
if key not in raw_config: if key not in raw_config:
logger.error(f"配置文件中缺少必需的字段: '{key}'") logger.error(f"配置文件中缺少必需的字段: '{key}'")
sys.exit(1) sys.exit(1)
self.nickname = raw_config["Nickname"].get("nickname")
self.server_host = raw_config["Napcat_Server"].get("host", "localhost") self.server_host = raw_config["Napcat_Server"].get("host", "localhost")
self.server_port = raw_config["Napcat_Server"].get("port", 8095) self.server_port = raw_config["Napcat_Server"].get("port", 8095)
self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost")
self.mai_port = raw_config["MaiBot_Server"].get("port", 8000)
self.platform = raw_config["MaiBot_Server"].get("platform_name") self.platform = raw_config["MaiBot_Server"].get("platform_name")
if not self.platform: if not self.platform:
logger.critical("请在配置文件中指定平台") logger.critical("请在配置文件中指定平台")
sys.exit(1) sys.exit(1)
self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost") self.list_type: str = raw_config["Chat"].get("list_type")
self.mai_port = raw_config["MaiBot_Server"].get("port", 8000) self.group_list: list = raw_config["Chat"].get("group_list", [])
self.debug_level = raw_config["Debug"].get("level", "INFO") self.user_list: list = raw_config["Chat"].get("user_list", [])
self.use_tts = raw_config["Voice"].get("use_tts", False) if not self.list_type or self.list_type not in ["whitelist", "blacklist"]:
else: logger.critical("请在配置文件中指定list_type或list_type填写错误")
logger.error("配置文件不存在!")
logger.info("正在创建配置文件...")
shutil.copy(
os.path.join(self.root_path, "template", "template_config.toml"),
os.path.join(self.root_path, "config.toml"),
)
logger.info("配置文件创建成功,请修改配置文件后重启程序。")
sys.exit(1) sys.exit(1)
self.use_tts = raw_config["Voice"].get("use_tts", False)
self.debug_level = raw_config["Debug"].get("level", "INFO")
global_config = Config() global_config = Config()
global_config.load_config() global_config.load_config()

View File

@ -88,6 +88,13 @@ class RecvHandler:
if sub_type == MessageType.Private.friend: if sub_type == MessageType.Private.friend:
sender_info: dict = raw_message.get("sender") sender_info: dict = raw_message.get("sender")
if global_config.list_type == "whitelist" and sender_info.get("user_id") not in global_config.user_list:
logger.warning("用户不在白名单中,消息被丢弃")
return None
if global_config.list_type == "blacklist" and sender_info.get("user_id") in global_config.user_list:
logger.warning("用户在黑名单中,消息被丢弃")
return None
# 发送者用户信息 # 发送者用户信息
user_info: UserInfo = UserInfo( user_info: UserInfo = UserInfo(
platform=global_config.platform, platform=global_config.platform,
@ -144,6 +151,13 @@ class RecvHandler:
if sub_type == MessageType.Group.normal: if sub_type == MessageType.Group.normal:
sender_info: dict = raw_message.get("sender") sender_info: dict = raw_message.get("sender")
if global_config.list_type == "whitelist" and raw_message.get("group_id") not in global_config.group_list:
logger.warning("群聊不在白名单中,消息被丢弃")
return None
if global_config.list_type == "blacklist" and raw_message.get("group_id") in global_config.group_list:
logger.warning("群聊在黑名单中,消息被丢弃")
return None
# 发送者用户信息 # 发送者用户信息
user_info: UserInfo = UserInfo( user_info: UserInfo = UserInfo(
platform=global_config.platform, platform=global_config.platform,
@ -400,8 +414,7 @@ class RecvHandler:
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id) member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id)
if member_info: if member_info:
return Seg( return Seg(
type=RealMessageType.text, type=RealMessageType.text, data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>"
data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>"
) )
else: else:
return None return None

View File

@ -3,6 +3,7 @@ import websockets as Server
import uuid import uuid
from .config import global_config from .config import global_config
# 白名单机制不启用 # 白名单机制不启用
from .message_queue import get_response from .message_queue import get_response
from .logger import logger from .logger import logger

View File

@ -15,9 +15,9 @@ import io
class SSLAdapter(urllib3.PoolManager): class SSLAdapter(urllib3.PoolManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
context = ssl.create_default_context() context = ssl.create_default_context()
context.set_ciphers('DEFAULT@SECLEVEL=1') context.set_ciphers("DEFAULT@SECLEVEL=1")
context.minimum_version = ssl.TLSVersion.TLSv1_2 context.minimum_version = ssl.TLSVersion.TLSv1_2
kwargs['ssl_context'] = context kwargs["ssl_context"] = context
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -74,7 +74,7 @@ async def get_image_base64(url: str) -> str:
"""获取图片/表情包的Base64""" """获取图片/表情包的Base64"""
http = SSLAdapter() http = SSLAdapter()
try: try:
response = http.request('GET', url, timeout=10) response = http.request("GET", url, timeout=10)
if response.status != 200: if response.status != 200:
raise Exception(f"HTTP Error: {response.status}") raise Exception(f"HTTP Error: {response.status}")
image_bytes = response.data image_bytes = response.data

View File

@ -11,10 +11,12 @@ platform_name = "qq" # 标识adapter的名称必填
host = "localhost" # 麦麦在.env文件中设置的主机地址即HOST字段 host = "localhost" # 麦麦在.env文件中设置的主机地址即HOST字段
port = 8000 # 麦麦在.env文件中设置的端口即PORT字段 port = 8000 # 麦麦在.env文件中设置的端口即PORT字段
[Whitelist] # 白名单功能(未启用)(未实现) [Chat] # 白名单功能(未启用)
group_list = [] list_type = "whitelist" # 使用的白名单类型可选为whitelist, blacklist
private_list = [] # 当list_type为white时使用白名单模式以下两个列表的含义是仅允许名单中的人聊天
enable_temp = false # 当list_type为black时使用黑名单模式以下两个列表的含义是禁止名单中的人聊天
group_list = [] # 群组名单
private_list = [] # 私聊名单
[Voice] # 发送语音设置 [Voice] # 发送语音设置
use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter