增加黑白名单机制
parent
9f381ab1a9
commit
0c94458ec1
2
main.py
2
main.py
|
|
@ -17,7 +17,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||||
logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message)
|
logger.debug(f"{raw_message[:80]}..." if len(raw_message) > 80 else raw_message)
|
||||||
decoded_raw_message: dict = json.loads(raw_message)
|
decoded_raw_message: dict = json.loads(raw_message)
|
||||||
post_type = decoded_raw_message.get("post_type")
|
post_type = decoded_raw_message.get("post_type")
|
||||||
if post_type in ["meta_event","message","notice"]:
|
if post_type in ["meta_event", "message", "notice"]:
|
||||||
await message_queue.put(decoded_raw_message)
|
await message_queue.put(decoded_raw_message)
|
||||||
elif post_type is None:
|
elif post_type is None:
|
||||||
await put_response(decoded_raw_message)
|
await put_response(decoded_raw_message)
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,16 @@ class Config:
|
||||||
self.config_path = os.path.join(self.root_path, "config.toml")
|
self.config_path = os.path.join(self.root_path, "config.toml")
|
||||||
|
|
||||||
def load_config(self): # sourcery skip: extract-method, move-assign
|
def load_config(self): # sourcery skip: extract-method, move-assign
|
||||||
include_configs = ["Nickname", "Napcat_Server", "MaiBot_Server", "Debug", "Voice"]
|
include_configs = ["Napcat_Server", "MaiBot_Server", "Chat", "Voice", "Debug"]
|
||||||
if os.path.exists(self.config_path):
|
if not os.path.exists(self.config_path):
|
||||||
|
logger.error("配置文件不存在!")
|
||||||
|
logger.info("正在创建配置文件...")
|
||||||
|
shutil.copy(
|
||||||
|
os.path.join(self.root_path, "template", "template_config.toml"),
|
||||||
|
os.path.join(self.root_path, "config.toml"),
|
||||||
|
)
|
||||||
|
logger.info("配置文件创建成功,请修改配置文件后重启程序。")
|
||||||
|
sys.exit(1)
|
||||||
with open(self.config_path, "rb") as f:
|
with open(self.config_path, "rb") as f:
|
||||||
try:
|
try:
|
||||||
raw_config = tomli.load(f)
|
raw_config = tomli.load(f)
|
||||||
|
|
@ -35,28 +43,29 @@ class Config:
|
||||||
if key not in raw_config:
|
if key not in raw_config:
|
||||||
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
self.nickname = raw_config["Nickname"].get("nickname")
|
|
||||||
self.server_host = raw_config["Napcat_Server"].get("host", "localhost")
|
self.server_host = raw_config["Napcat_Server"].get("host", "localhost")
|
||||||
self.server_port = raw_config["Napcat_Server"].get("port", 8095)
|
self.server_port = raw_config["Napcat_Server"].get("port", 8095)
|
||||||
|
self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
|
||||||
|
|
||||||
|
self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost")
|
||||||
|
self.mai_port = raw_config["MaiBot_Server"].get("port", 8000)
|
||||||
self.platform = raw_config["MaiBot_Server"].get("platform_name")
|
self.platform = raw_config["MaiBot_Server"].get("platform_name")
|
||||||
if not self.platform:
|
if not self.platform:
|
||||||
logger.critical("请在配置文件中指定平台")
|
logger.critical("请在配置文件中指定平台")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
self.napcat_heartbeat_interval = raw_config["Napcat_Server"].get("heartbeat", 30)
|
|
||||||
self.mai_host = raw_config["MaiBot_Server"].get("host", "localhost")
|
self.list_type: str = raw_config["Chat"].get("list_type")
|
||||||
self.mai_port = raw_config["MaiBot_Server"].get("port", 8000)
|
self.group_list: list = raw_config["Chat"].get("group_list", [])
|
||||||
self.debug_level = raw_config["Debug"].get("level", "INFO")
|
self.user_list: list = raw_config["Chat"].get("user_list", [])
|
||||||
self.use_tts = raw_config["Voice"].get("use_tts", False)
|
if not self.list_type or self.list_type not in ["whitelist", "blacklist"]:
|
||||||
else:
|
logger.critical("请在配置文件中指定list_type或list_type填写错误")
|
||||||
logger.error("配置文件不存在!")
|
|
||||||
logger.info("正在创建配置文件...")
|
|
||||||
shutil.copy(
|
|
||||||
os.path.join(self.root_path, "template", "template_config.toml"),
|
|
||||||
os.path.join(self.root_path, "config.toml"),
|
|
||||||
)
|
|
||||||
logger.info("配置文件创建成功,请修改配置文件后重启程序。")
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
self.use_tts = raw_config["Voice"].get("use_tts", False)
|
||||||
|
|
||||||
|
self.debug_level = raw_config["Debug"].get("level", "INFO")
|
||||||
|
|
||||||
|
|
||||||
global_config = Config()
|
global_config = Config()
|
||||||
global_config.load_config()
|
global_config.load_config()
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,13 @@ class RecvHandler:
|
||||||
if sub_type == MessageType.Private.friend:
|
if sub_type == MessageType.Private.friend:
|
||||||
sender_info: dict = raw_message.get("sender")
|
sender_info: dict = raw_message.get("sender")
|
||||||
|
|
||||||
|
if global_config.list_type == "whitelist" and sender_info.get("user_id") not in global_config.user_list:
|
||||||
|
logger.warning("用户不在白名单中,消息被丢弃")
|
||||||
|
return None
|
||||||
|
if global_config.list_type == "blacklist" and sender_info.get("user_id") in global_config.user_list:
|
||||||
|
logger.warning("用户在黑名单中,消息被丢弃")
|
||||||
|
return None
|
||||||
|
|
||||||
# 发送者用户信息
|
# 发送者用户信息
|
||||||
user_info: UserInfo = UserInfo(
|
user_info: UserInfo = UserInfo(
|
||||||
platform=global_config.platform,
|
platform=global_config.platform,
|
||||||
|
|
@ -144,6 +151,13 @@ class RecvHandler:
|
||||||
if sub_type == MessageType.Group.normal:
|
if sub_type == MessageType.Group.normal:
|
||||||
sender_info: dict = raw_message.get("sender")
|
sender_info: dict = raw_message.get("sender")
|
||||||
|
|
||||||
|
if global_config.list_type == "whitelist" and raw_message.get("group_id") not in global_config.group_list:
|
||||||
|
logger.warning("群聊不在白名单中,消息被丢弃")
|
||||||
|
return None
|
||||||
|
if global_config.list_type == "blacklist" and raw_message.get("group_id") in global_config.group_list:
|
||||||
|
logger.warning("群聊在黑名单中,消息被丢弃")
|
||||||
|
return None
|
||||||
|
|
||||||
# 发送者用户信息
|
# 发送者用户信息
|
||||||
user_info: UserInfo = UserInfo(
|
user_info: UserInfo = UserInfo(
|
||||||
platform=global_config.platform,
|
platform=global_config.platform,
|
||||||
|
|
@ -400,8 +414,7 @@ class RecvHandler:
|
||||||
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id)
|
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id)
|
||||||
if member_info:
|
if member_info:
|
||||||
return Seg(
|
return Seg(
|
||||||
type=RealMessageType.text,
|
type=RealMessageType.text, data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>"
|
||||||
data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import websockets as Server
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
|
||||||
# 白名单机制不启用
|
# 白名单机制不启用
|
||||||
from .message_queue import get_response
|
from .message_queue import get_response
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,9 @@ import io
|
||||||
class SSLAdapter(urllib3.PoolManager):
|
class SSLAdapter(urllib3.PoolManager):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
context = ssl.create_default_context()
|
context = ssl.create_default_context()
|
||||||
context.set_ciphers('DEFAULT@SECLEVEL=1')
|
context.set_ciphers("DEFAULT@SECLEVEL=1")
|
||||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
kwargs['ssl_context'] = context
|
kwargs["ssl_context"] = context
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ async def get_image_base64(url: str) -> str:
|
||||||
"""获取图片/表情包的Base64"""
|
"""获取图片/表情包的Base64"""
|
||||||
http = SSLAdapter()
|
http = SSLAdapter()
|
||||||
try:
|
try:
|
||||||
response = http.request('GET', url, timeout=10)
|
response = http.request("GET", url, timeout=10)
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise Exception(f"HTTP Error: {response.status}")
|
raise Exception(f"HTTP Error: {response.status}")
|
||||||
image_bytes = response.data
|
image_bytes = response.data
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,12 @@ platform_name = "qq" # 标识adapter的名称(必填)
|
||||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||||
|
|
||||||
[Whitelist] # 白名单功能(未启用)(未实现)
|
[Chat] # 白名单功能(未启用)
|
||||||
group_list = []
|
list_type = "whitelist" # 使用的白名单类型,可选为:whitelist, blacklist
|
||||||
private_list = []
|
# 当list_type为white时,使用白名单模式,以下两个列表的含义是:仅允许名单中的人聊天
|
||||||
enable_temp = false
|
# 当list_type为black时,使用黑名单模式,以下两个列表的含义是:禁止名单中的人聊天
|
||||||
|
group_list = [] # 群组名单
|
||||||
|
private_list = [] # 私聊名单
|
||||||
|
|
||||||
[Voice] # 发送语音设置
|
[Voice] # 发送语音设置
|
||||||
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue