戳一戳解析与ruff规范

pull/3/head
UnCLAS-Prommer 2025-04-11 13:16:21 +08:00
parent c94e55e7e2
commit fd12aac300
10 changed files with 231 additions and 70 deletions

View File

@ -42,9 +42,11 @@ enable_temp = false
- [ ] 回复解析(?)
- [ ] 群临时消息(可能不做)
- [ ] 链接解析
- [ ] 戳一戳解析
- [x] 戳一戳解析
- [ ] 读取戳一戳的自定义内容(?)
- [ ] 语音解析(?)
- [ ] 所有的notice类
- [ ] <del>撤回</del>
- [x] 发送消息
- [x] 发送文本
- [x] 发送图片

12
main.py
View File

@ -22,7 +22,7 @@ async def message_recv(server_connection: Server.ServerConnection):
elif post_type == "message":
await message_queue.put(decoded_raw_message)
elif post_type == "notice":
pass
await message_queue.put(decoded_raw_message)
elif post_type is None:
await recv_queue.put(decoded_raw_message)
@ -36,7 +36,7 @@ async def message_process():
elif post_type == "meta_event":
await recv_handler.handle_meta_event(message)
elif post_type == "notice":
await recv_handler.handle_notify(message)
await recv_handler.handle_notice(message)
else:
logger.warning(f"未知的post_type: {post_type}")
message_queue.task_done()
@ -50,12 +50,8 @@ async def main():
async def napcat_server():
logger.info("正在启动adapter...")
async with Server.serve(
message_recv, global_config.server_host, global_config.server_port
) as server:
logger.info(
f"Adapter已启动监听地址: ws://{global_config.server_host}:{global_config.server_port}"
)
async with Server.serve(message_recv, global_config.server_host, global_config.server_port) as server:
logger.info(f"Adapter已启动监听地址: ws://{global_config.server_host}:{global_config.server_port}")
await server.serve_forever()

44
pyproject.toml 100644
View File

@ -0,0 +1,44 @@
[project]
name = "MaiBotNapcatAdapter"
version = "0.1.0"
description = "A MaiBot adapter for Napcat"
[tool.ruff]
include = ["*.py"]
# 行长度设置
line-length = 120
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
ignore = ["E711","E501"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# 使用双引号表示字符串
quote-style = "double"
# 尊重魔法尾随逗号
# 例如:
# items = [
# "apple",
# "banana",
# "cherry",
# ]
skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"

View File

@ -1,4 +1,4 @@
class MetaEventType():
class MetaEventType:
lifecycle = "lifecycle" # 生命周期
class Lifecycle:
@ -27,6 +27,7 @@ class MessageType: # 接受消息大类
class NoticeType: # 通知事件
friend_recall = "friend_recall" # 私聊消息撤回
group_recall = "group_recall" # 群聊消息撤回
notify = "notify"
class Notify:
poke = "poke" # 戳一戳
@ -46,13 +47,17 @@ class RealMessageType: # 实际消息分类
share = "share" # 链接分享json形式
reply = "reply" # 回复消息
forward = "forward" # 转发消息
node = "node" # 转发消息节点
node = "node" # 转发消息节点
class MessageSentType:
private = "private"
class Private:
friend = "friend"
group = "group"
group = "group"
class Group:
normal = "normal"
normal = "normal"

View File

@ -34,9 +34,7 @@ class Config:
try:
raw_config = tomli.load(f)
except tomli.TOMLDecodeError as e:
logger.critical(
f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
)
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
sys.exit(1)
for key in include_configs:
if key not in raw_config:

View File

@ -1,5 +1,5 @@
from loguru import logger
import builtins
# import builtins
def handle_output(message: str):

View File

@ -3,7 +3,8 @@ import asyncio
recv_queue = asyncio.Queue()
message_queue = asyncio.Queue()
async def get_response():
response = await recv_queue.get()
recv_queue.task_done()
return response
return response

View File

@ -18,7 +18,13 @@ from maim_message import (
Router,
)
from .utils import get_group_info, get_member_info, get_image_base64, get_self_info
from .utils import (
get_group_info,
get_member_info,
get_image_base64,
get_self_info,
get_stranger_info,
)
from .message_queue import get_response
@ -65,7 +71,8 @@ class RecvHandler:
"""
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
message_time: int = raw_message.get("time")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
format_info: FormatInfo = None # 格式化信息,暂时为空,等待启用
@ -114,9 +121,7 @@ class RecvHandler:
# -------------------这里需要群信息吗?-------------------
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info: dict = await get_group_info(
self.server_connection, raw_message.get("group_id")
)
fetched_group_info: dict = await get_group_info(self.server_connection, raw_message.get("group_id"))
group_name = ""
if fetched_group_info.get("group_name"):
group_name = fetched_group_info.get("group_name")
@ -144,9 +149,7 @@ class RecvHandler:
)
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info = await get_group_info(
self.server_connection, raw_message.get("group_id")
)
fetched_group_info = await get_group_info(self.server_connection, raw_message.get("group_id"))
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
@ -319,9 +322,7 @@ class RecvHandler:
seg_data = Seg(type="emoji", data=image_base64)
return seg_data
async def handle_at_message(
self, raw_message: dict, self_id: int, group_id: int
) -> Seg:
async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg:
"""
处理at消息
Parameters:
@ -337,15 +338,11 @@ class RecvHandler:
if str(self_id) == str(qq_id):
self_info: dict = await get_self_info(self.server_connection)
if self_info:
return Seg(
type=RealMessageType.text, data=f"@{self_info.get('nickname')}"
)
return Seg(type=RealMessageType.text, data=f"@{self_info.get('nickname')}")
else:
return None
else:
member_info: dict = await get_member_info(
self.server_connection, group_id=group_id, user_id=self_id
)
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=self_id)
if member_info:
return Seg(
type=RealMessageType.text,
@ -354,24 +351,130 @@ class RecvHandler:
else:
return None
async def handle_notify(self, raw_message: dict) -> None:
async def handle_notice(self, raw_message: dict) -> None:
notice_type = raw_message.get("notice_type")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
group_id = raw_message.get("group_id")
user_id = raw_message.get("user_id")
handled_message: Seg = None
match notice_type:
case NoticeType.friend_recall:
logger.info("用户撤回一条消息")
logger.info("好友撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
pass
case NoticeType.group_recall:
logger.info("群内用户撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
pass
case NoticeType.Notify:
case NoticeType.notify:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
logger.info("用户戳了一戳")
pass
async def handle_poke_notify(self) -> None:
pass
handled_message: Seg = await self.handle_poke_notify(raw_message)
if not handled_message:
logger.warning("notice处理失败或不支持")
return None
source_name: str = None
source_cardname: str = None
if group_id:
member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
if member_info:
source_name = member_info.get("nickname")
source_cardname = member_info.get("card")
else:
logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效")
source_name = "QQ用户"
else:
stranger_info = await get_stranger_info(self.server_connection, user_id)
if stranger_info:
source_name = stranger_info.get("nickname")
else:
logger.warning("无法获取戳一戳消息发送者的昵称,消息可能会无效")
source_name = "QQ用户"
user_info: UserInfo = UserInfo(
platform=global_config.platform,
user_id=user_id,
user_nickname=source_name,
user_cardname=source_cardname,
)
group_info: GroupInfo = None
if group_id:
fetched_group_info = await get_group_info(self.server_connection, group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
group_info = GroupInfo(
platform=global_config.platform,
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.platform,
message_id="notice",
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=None,
format_info=None,
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=handled_message,
raw_message=json.dumps(raw_message),
)
logger.info("发送到Maibot处理通知信息")
await self.message_process(message_base)
async def handle_poke_notify(self, raw_message: dict) -> Seg:
self_info: dict = await get_self_info(self.server_connection)
if not self_info:
logger.error("自身信息获取失败")
return None
self_id = raw_message.get("self_id")
target_id = raw_message.get("target_id")
target_name: str = None
raw_info: list = raw_message.get("raw_info")
# 计算Seg
if self_id == target_id:
if self_info:
target_name = self_info.get("nickname")
else:
logger.warning("无法获取bot的昵称戳一戳消息可能无效")
target_name = ""
try:
first_txt = raw_info[2].get("text", "戳了戳")
second_txt = raw_info[4].get("text", "")
except Exception as e:
logger.warning(f"解析戳一戳消息失败,使用默认文本:{str(e)}")
first_txt = "戳了戳"
second_txt = ""
"""
# 不启用戳其他人的处理
else:
# 由于Napcat不支持获取昵称所以需要单独获取
group_id = raw_message.get("group_id")
fetched_member_info: dict = await get_member_info(
self.server_connection, group_id, target_id
)
if fetched_member_info:
target_name = fetched_member_info.get("nickname")
"""
seg_data: Seg = Seg(
type="text",
data=f"{first_txt}{target_name}{second_txt}(这是一个类似摸摸头的友善行为,不是恶意行为,请不要作出攻击发言)",
)
return seg_data
async def handle_forward_message(self, message_list: list) -> Seg:
"""
@ -379,24 +482,18 @@ class RecvHandler:
Parameters:
message_list: list: 转发消息列表
"""
handled_message, image_count = await self._handle_forward_message(
message_list, 0
)
handled_message, image_count = await self._handle_forward_message(message_list, 0)
handled_message: Seg
image_count: int
if not handled_message:
return None
if image_count < 5 and image_count > 0:
# 处理图片数量小于5的情况此时解析图片为base64
parsed_handled_message = await self._recursive_parse_image_seg(
handled_message, True
)
parsed_handled_message = await self._recursive_parse_image_seg(handled_message, True)
return parsed_handled_message
elif image_count > 0:
# 处理图片数量大于等于5的情况此时解析图片为占位符
parsed_handled_message = await self._recursive_parse_image_seg(
handled_message, False
)
parsed_handled_message = await self._recursive_parse_image_seg(handled_message, False)
return parsed_handled_message
else:
# 处理没有图片的情况,此时直接返回
@ -436,9 +533,7 @@ class RecvHandler:
else:
return seg_data
async def _handle_forward_message(
self, message_list: list, layer: int
) -> Tuple[Seg, int]:
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int]:
"""
递归处理实际转发消息
Parameters:
@ -470,14 +565,11 @@ class RecvHandler:
)
else:
contents = message_of_sub_message.get("data").get("content")
seg_data, count = await self._handle_forward_message(
contents, layer + 1
)
seg_data, count = await self._handle_forward_message(contents, layer + 1)
image_count += count
head_tip = Seg(
type="text",
data=("--" * layer)
+ f"{user_nickname}】: 合并转发消息内容:\n",
data=("--" * layer) + f"{user_nickname}】: 合并转发消息内容:\n",
)
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
seg_list.append(full_seg_data)
@ -489,9 +581,7 @@ class RecvHandler:
Seg(
type="seglist",
data=[
Seg(
type="text", data=("--" * layer) + user_nickname_str
),
Seg(type="text", data=("--" * layer) + user_nickname_str),
seg_data,
break_seg,
],

View File

@ -74,7 +74,7 @@ class SendHandler:
async def handle_seg_recursive(self, seg_data: Seg) -> list:
payload: list = []
if seg_data.type == "seglist":
level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
for seg in seg_data.data:
payload = self.process_message_by_type(seg, payload)
else:
@ -85,6 +85,8 @@ class SendHandler:
new_payload = payload
if seg.type == "reply":
target_id = seg.data
if target_id == "notice":
return []
new_payload = self.build_payload(
payload, self.handle_reply_message(target_id), True
)

View File

@ -21,9 +21,7 @@ class SSLAdapter(HTTPAdapter):
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
ssl_context.check_hostname = False # 避免在请求时 verify=False 设置时报错, 如果设置需要校验证书可去掉该行。
ssl_context.minimum_version = (
ssl.TLSVersion.TLSv1_2
) # 最小版本设置成1.2 可去掉低版本的警告
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 # 最小版本设置成1.2 可去掉低版本的警告
ssl_context.maximum_version = ssl.TLSVersion.TLSv1_2 # 最大版本设置成1.2
kwargs["ssl_context"] = ssl_context
return super().init_poolmanager(*args, **kwargs)
@ -42,9 +40,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
return socket_response.get("data")
async def get_member_info(
websocket: Server.ServerConnection, group_id: int, user_id: int
) -> dict:
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict:
"""
获取群成员信息
@ -89,9 +85,13 @@ def convert_image_to_gif(image_base64: str) -> str:
return image_base64
async def get_self_info(websocket: Server.ServerConnection) -> str:
async def get_self_info(websocket: Server.ServerConnection) -> dict:
"""
获取自身信息
Parameters:
websocket: WebSocket连接对象
Returns:
data: dict: 返回的自身信息
"""
payload = json.dumps({"action": "get_login_info", "params": {}})
await websocket.send(payload)
@ -101,5 +101,28 @@ async def get_self_info(websocket: Server.ServerConnection) -> str:
def get_image_format(raw_data: str) -> str:
"""
从Base64编码的数据中确定图片的格式
Parameters:
raw_data: str: Base64编码的图片数据
Returns:
format: str: 图片的格式例如 'jpeg', 'png', 'gif'
"""
image_bytes = base64.b64decode(raw_data)
return Image.open(io.BytesIO(image_bytes)).format.lower()
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict:
"""
获取陌生人信息
Parameters:
websocket: WebSocket连接对象
user_id: 用户ID
Returns:
dict: 返回的陌生人信息
"""
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}})
await websocket.send(payload)
response: dict = await get_response()
logger.debug(response)
return response.get("data")