修bug,改版本号

dev-from070-to080
UnCLAS-Prommer 2025-06-28 11:44:35 +08:00
parent 1196909521
commit 5c57ba9c85
8 changed files with 80 additions and 37 deletions

4
.gitignore vendored
View File

@ -272,4 +272,6 @@ $RECYCLE.BIN/
config.toml
config.toml.back
test
data/NapcatAdapter.db
data/NapcatAdapter.db
data/NapcatAdapter.db-shm
data/NapcatAdapter.db-wal

View File

@ -16,9 +16,9 @@ message_queue = asyncio.Queue()
async def message_recv(server_connection: Server.ServerConnection):
message_handler.set_server_connection(server_connection)
notice_handler.set_server_connection(server_connection)
send_handler.set_server_connection(server_connection)
await message_handler.set_server_connection(server_connection)
asyncio.create_task(notice_handler.set_server_connection(server_connection))
await send_handler.set_server_connection(server_connection)
async for raw_message in server_connection:
logger.debug(
f"{raw_message[:100]}..."

View File

@ -1,6 +1,6 @@
[project]
name = "MaiBotNapcatAdapter"
version = "0.2.6"
version = "0.3.0"
description = "A MaiBot adapter for Napcat"
[tool.ruff]

View File

@ -1,5 +1,6 @@
import os
from typing import Optional, List
from dataclasses import dataclass
from sqlmodel import Field, Session, SQLModel, create_engine, select
from src.logger import logger
@ -13,7 +14,18 @@ from src.logger import logger
"""
class BanUser(SQLModel, table=True):
@dataclass
class BanUser:
"""
程序处理使用的实例
"""
user_id: int
group_id: int
lift_time: Optional[int] = Field(default=-1)
class DB_BanUser(SQLModel, table=True):
"""
表示数据库中的用户禁言记录
使用双重主键
@ -24,7 +36,7 @@ class BanUser(SQLModel, table=True):
lift_time: Optional[int] # 禁言解除的时间(时间戳)
def is_identical(self, obj1: BanUser, obj2: BanUser) -> bool:
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
"""
检查两个 BanUser 对象是否相同
"""
@ -51,15 +63,16 @@ class DatabaseManager:
logger.success("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库
支持在不存在时创建新记录对于多余的项目自动删除
"""
with Session(self.engine) as session:
all_records = session.exec(select(BanUser)).all()
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(BanUser).where(
BanUser.user_id == ban_user.user_id, BanUser.group_id == ban_user.group_id
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
@ -71,13 +84,24 @@ class DatabaseManager:
logger.debug(f"更新禁言记录: {existing_record}")
else:
# 创建新记录
session.add(ban_user)
db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for record in all_records:
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
session.delete(record)
logger.debug(f"删除禁言记录: {record}")
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
session.commit()
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: {ban_record}")
session.commit()
logger.info("禁言记录已更新")
@ -87,8 +111,9 @@ class DatabaseManager:
读取所有禁言记录
"""
with Session(self.engine) as session:
statement = select(BanUser)
return session.exec(statement).all()
statement = select(DB_BanUser)
records = session.exec(statement).all()
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
def create_ban_record(self, ban_record: BanUser) -> None:
"""
@ -97,7 +122,10 @@ class DatabaseManager:
其同时还是简化版的更新方式
"""
with Session(self.engine) as session:
session.add(ban_record)
db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
)
session.add(db_record)
session.commit()
logger.debug(f"创建/更新禁言记录: {ban_record}")
@ -109,7 +137,7 @@ class DatabaseManager:
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(BanUser).where(BanUser.user_id == user_id, BanUser.group_id == group_id)
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
session.commit()

View File

@ -36,14 +36,14 @@ class MessageHandler:
self.server_connection: Server.ServerConnection = None
self.bot_id_list: Dict[int, bool] = {}
def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
async def check_allow_to_chat(
self,
user_id: int,
group_id: Optional[int],
group_id: Optional[int] = None,
ignore_bot: Optional[bool] = False,
ignore_global_list: Optional[bool] = False,
) -> bool:

View File

@ -21,7 +21,7 @@ class MessageSending:
try:
send_status = await self.maibot_router.send_message(message_base)
if not send_status:
raise RuntimeError("发送消息失败,可能是路由未正确配置或连接异常")
raise RuntimeError("可能是路由未正确配置或连接异常")
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")

View File

@ -35,6 +35,8 @@ class NoticeHandler:
"""设置Napcat连接"""
self.server_connection = server_connection
while self.server_connection.state != Server.State.OPEN:
await asyncio.sleep(0.5)
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
asyncio.create_task(self.auto_lift_detect())
@ -59,7 +61,7 @@ class NoticeHandler:
self.banned_list.append(ban_record)
db_manager.create_ban_record(ban_record) # 添加到数据库
def _lift_operation(self, group_id: int, user_id: Optional[int]) -> None:
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
"""
从self.lifted_group_list中移除已经解除全体禁言的群
"""
@ -77,12 +79,9 @@ class NoticeHandler:
group_id = raw_message.get("group_id")
user_id = raw_message.get("user_id")
# if not await self.check_allow_to_chat(user_id, group_id):
# logger.warning("notice消息被丢弃")
# return None
handled_message: Seg = None
user_info: UserInfo = None
system_notice: bool = False
match notice_type:
case NoticeType.friend_recall:
@ -110,15 +109,17 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.GroupBan.ban:
if await message_handler.check_allow_to_chat(user_id, group_id, True, False):
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理群禁言")
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
system_notice = True
case NoticeType.GroupBan.lift_ban:
if await message_handler.check_allow_to_chat(user_id, group_id, True, False):
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理解除群禁言")
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
system_notice = True
case _:
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
case _:
@ -158,8 +159,11 @@ class NoticeHandler:
raw_message=json.dumps(raw_message),
)
logger.info("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
if system_notice:
await self.put_notice(message_base)
else:
logger.info("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
async def handle_poke_notify(self, raw_message: dict, group_id: int, user_id: int) -> Tuple[Seg | None, UserInfo]:
self_info: dict = await get_self_info(self.server_connection)
@ -355,6 +359,15 @@ class NoticeHandler:
)
return seg_data, operator_info
async def put_notice(self, message_base: MessageBase) -> None:
"""
将处理后的通知消息放入通知队列
"""
if notice_queue.full() or unsuccessful_notice_queue.full():
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
else:
await notice_queue.put(message_base)
async def handle_natural_lift(self) -> None:
while True:
if len(self.lifted_list) != 0:
@ -402,11 +415,8 @@ class NoticeHandler:
}
),
)
if notice_queue.full() or unsuccessful_notice_queue.full():
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
else:
await notice_queue.put(message_base)
await self.put_notice(message_base)
await asyncio.sleep(0.5) # 确保队列处理间隔
else:
await asyncio.sleep(5) # 每5秒检查一次
@ -449,6 +459,9 @@ class NoticeHandler:
async def auto_lift_detect(self) -> None:
while True:
if len(self.banned_list) == 0:
await asyncio.sleep(5)
continue
for ban_record in self.banned_list:
if ban_record.user_id == 0 or ban_record.lift_time == -1:
continue
@ -457,7 +470,7 @@ class NoticeHandler:
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
self.lifted_list.append(ban_record)
self.banned_list.remove(ban_record)
asyncio.sleep(5)
await asyncio.sleep(5)
async def send_notice(self) -> None:
"""
@ -475,7 +488,7 @@ class NoticeHandler:
except Exception as e:
logger.error(f"发送通知消息失败: {str(e)}")
await unsuccessful_notice_queue.put(to_be_send)
asyncio.sleep(0.2)
await asyncio.sleep(1)
continue
to_be_send: MessageBase = await notice_queue.get()
try:

View File

@ -21,7 +21,7 @@ class SendHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection