修bug,改版本号
parent
1196909521
commit
5c57ba9c85
|
|
@ -272,4 +272,6 @@ $RECYCLE.BIN/
|
|||
config.toml
|
||||
config.toml.back
|
||||
test
|
||||
data/NapcatAdapter.db
|
||||
data/NapcatAdapter.db
|
||||
data/NapcatAdapter.db-shm
|
||||
data/NapcatAdapter.db-wal
|
||||
6
main.py
6
main.py
|
|
@ -16,9 +16,9 @@ message_queue = asyncio.Queue()
|
|||
|
||||
|
||||
async def message_recv(server_connection: Server.ServerConnection):
|
||||
message_handler.set_server_connection(server_connection)
|
||||
notice_handler.set_server_connection(server_connection)
|
||||
send_handler.set_server_connection(server_connection)
|
||||
await message_handler.set_server_connection(server_connection)
|
||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||
await send_handler.set_server_connection(server_connection)
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(
|
||||
f"{raw_message[:100]}..."
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "MaiBotNapcatAdapter"
|
||||
version = "0.2.6"
|
||||
version = "0.3.0"
|
||||
description = "A MaiBot adapter for Napcat"
|
||||
|
||||
[tool.ruff]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from src.logger import logger
|
||||
|
|
@ -13,7 +14,18 @@ from src.logger import logger
|
|||
"""
|
||||
|
||||
|
||||
class BanUser(SQLModel, table=True):
|
||||
@dataclass
|
||||
class BanUser:
|
||||
"""
|
||||
程序处理使用的实例
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
group_id: int
|
||||
lift_time: Optional[int] = Field(default=-1)
|
||||
|
||||
|
||||
class DB_BanUser(SQLModel, table=True):
|
||||
"""
|
||||
表示数据库中的用户禁言记录。
|
||||
使用双重主键
|
||||
|
|
@ -24,7 +36,7 @@ class BanUser(SQLModel, table=True):
|
|||
lift_time: Optional[int] # 禁言解除的时间(时间戳)
|
||||
|
||||
|
||||
def is_identical(self, obj1: BanUser, obj2: BanUser) -> bool:
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
|
||||
"""
|
||||
检查两个 BanUser 对象是否相同。
|
||||
"""
|
||||
|
|
@ -51,15 +63,16 @@ class DatabaseManager:
|
|||
logger.success("数据库和表已创建或已存在")
|
||||
|
||||
def update_ban_record(self, ban_list: List[BanUser]) -> None:
|
||||
# sourcery skip: class-extract-method
|
||||
"""
|
||||
更新禁言列表到数据库。
|
||||
支持在不存在时创建新记录,对于多余的项目自动删除。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
all_records = session.exec(select(BanUser)).all()
|
||||
all_records = session.exec(select(DB_BanUser)).all()
|
||||
for ban_user in ban_list:
|
||||
statement = select(BanUser).where(
|
||||
BanUser.user_id == ban_user.user_id, BanUser.group_id == ban_user.group_id
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
|
||||
)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.lift_time == ban_user.lift_time:
|
||||
|
|
@ -71,13 +84,24 @@ class DatabaseManager:
|
|||
logger.debug(f"更新禁言记录: {existing_record}")
|
||||
else:
|
||||
# 创建新记录
|
||||
session.add(ban_user)
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
logger.debug(f"创建新禁言记录: {ban_user}")
|
||||
# 删除不在 ban_list 中的记录
|
||||
for record in all_records:
|
||||
for db_record in all_records:
|
||||
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
|
||||
if not any(is_identical(record, ban_user) for ban_user in ban_list):
|
||||
session.delete(record)
|
||||
logger.debug(f"删除禁言记录: {record}")
|
||||
statement = select(DB_BanUser).where(
|
||||
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
|
||||
)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
session.commit()
|
||||
logger.debug(f"删除禁言记录: {ban_record}")
|
||||
else:
|
||||
logger.info(f"未找到禁言记录: {ban_record}")
|
||||
|
||||
session.commit()
|
||||
logger.info("禁言记录已更新")
|
||||
|
|
@ -87,8 +111,9 @@ class DatabaseManager:
|
|||
读取所有禁言记录。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
statement = select(BanUser)
|
||||
return session.exec(statement).all()
|
||||
statement = select(DB_BanUser)
|
||||
records = session.exec(statement).all()
|
||||
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
def create_ban_record(self, ban_record: BanUser) -> None:
|
||||
"""
|
||||
|
|
@ -97,7 +122,10 @@ class DatabaseManager:
|
|||
其同时还是简化版的更新方式。
|
||||
"""
|
||||
with Session(self.engine) as session:
|
||||
session.add(ban_record)
|
||||
db_record = DB_BanUser(
|
||||
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
|
||||
)
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
logger.debug(f"创建/更新禁言记录: {ban_record}")
|
||||
|
||||
|
|
@ -109,7 +137,7 @@ class DatabaseManager:
|
|||
user_id = ban_record.user_id
|
||||
group_id = ban_record.group_id
|
||||
with Session(self.engine) as session:
|
||||
statement = select(BanUser).where(BanUser.user_id == user_id, BanUser.group_id == group_id)
|
||||
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
|
||||
if ban_record := session.exec(statement).first():
|
||||
session.delete(ban_record)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -36,14 +36,14 @@ class MessageHandler:
|
|||
self.server_connection: Server.ServerConnection = None
|
||||
self.bot_id_list: Dict[int, bool] = {}
|
||||
|
||||
def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
async def check_allow_to_chat(
|
||||
self,
|
||||
user_id: int,
|
||||
group_id: Optional[int],
|
||||
group_id: Optional[int] = None,
|
||||
ignore_bot: Optional[bool] = False,
|
||||
ignore_global_list: Optional[bool] = False,
|
||||
) -> bool:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class MessageSending:
|
|||
try:
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常")
|
||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ class NoticeHandler:
|
|||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
while self.server_connection.state != Server.State.OPEN:
|
||||
await asyncio.sleep(0.5)
|
||||
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
|
||||
|
||||
asyncio.create_task(self.auto_lift_detect())
|
||||
|
|
@ -59,7 +61,7 @@ class NoticeHandler:
|
|||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record) # 添加到数据库
|
||||
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int]) -> None:
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
从self.lifted_group_list中移除已经解除全体禁言的群
|
||||
"""
|
||||
|
|
@ -77,12 +79,9 @@ class NoticeHandler:
|
|||
group_id = raw_message.get("group_id")
|
||||
user_id = raw_message.get("user_id")
|
||||
|
||||
# if not await self.check_allow_to_chat(user_id, group_id):
|
||||
# logger.warning("notice消息被丢弃")
|
||||
# return None
|
||||
|
||||
handled_message: Seg = None
|
||||
user_info: UserInfo = None
|
||||
system_notice: bool = False
|
||||
|
||||
match notice_type:
|
||||
case NoticeType.friend_recall:
|
||||
|
|
@ -110,15 +109,17 @@ class NoticeHandler:
|
|||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.GroupBan.ban:
|
||||
if await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理群禁言")
|
||||
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case NoticeType.GroupBan.lift_ban:
|
||||
if await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理解除群禁言")
|
||||
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case _:
|
||||
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
|
||||
case _:
|
||||
|
|
@ -158,8 +159,11 @@ class NoticeHandler:
|
|||
raw_message=json.dumps(raw_message),
|
||||
)
|
||||
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
else:
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]:
|
||||
self_info: dict = await get_self_info(self.server_connection)
|
||||
|
|
@ -355,6 +359,15 @@ class NoticeHandler:
|
|||
)
|
||||
return seg_data, operator_info
|
||||
|
||||
async def put_notice(self, message_base: MessageBase) -> None:
|
||||
"""
|
||||
将处理后的通知消息放入通知队列
|
||||
"""
|
||||
if notice_queue.full() or unsuccessful_notice_queue.full():
|
||||
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
|
||||
else:
|
||||
await notice_queue.put(message_base)
|
||||
|
||||
async def handle_natural_lift(self) -> None:
|
||||
while True:
|
||||
if len(self.lifted_list) != 0:
|
||||
|
|
@ -402,11 +415,8 @@ class NoticeHandler:
|
|||
}
|
||||
),
|
||||
)
|
||||
if notice_queue.full() or unsuccessful_notice_queue.full():
|
||||
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
|
||||
else:
|
||||
await notice_queue.put(message_base)
|
||||
|
||||
await self.put_notice(message_base)
|
||||
await asyncio.sleep(0.5) # 确保队列处理间隔
|
||||
else:
|
||||
await asyncio.sleep(5) # 每5秒检查一次
|
||||
|
|
@ -449,6 +459,9 @@ class NoticeHandler:
|
|||
|
||||
async def auto_lift_detect(self) -> None:
|
||||
while True:
|
||||
if len(self.banned_list) == 0:
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
for ban_record in self.banned_list:
|
||||
if ban_record.user_id == 0 or ban_record.lift_time == -1:
|
||||
continue
|
||||
|
|
@ -457,7 +470,7 @@ class NoticeHandler:
|
|||
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
|
||||
self.lifted_list.append(ban_record)
|
||||
self.banned_list.remove(ban_record)
|
||||
asyncio.sleep(5)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def send_notice(self) -> None:
|
||||
"""
|
||||
|
|
@ -475,7 +488,7 @@ class NoticeHandler:
|
|||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
asyncio.sleep(0.2)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
to_be_send: MessageBase = await notice_queue.get()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class SendHandler:
|
|||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
|
||||
def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue