语音消息识别
parent
92af300035
commit
60e9106a89
6
main.py
6
main.py
|
|
@ -20,11 +20,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||
await send_handler.set_server_connection(server_connection)
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(
|
||||
f"{raw_message[:100]}..."
|
||||
if (len(raw_message) > 100 and global_config.debug.level != "DEBUG")
|
||||
else raw_message
|
||||
)
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "MaiBotNapcatAdapter"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
description = "A MaiBot adapter for Napcat"
|
||||
|
||||
[tool.ruff]
|
||||
|
|
|
|||
|
|
@ -82,3 +82,6 @@ class CommandType(Enum):
|
|||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command"]
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ from src.utils import (
|
|||
get_group_info,
|
||||
get_member_info,
|
||||
get_image_base64,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
get_message_detail,
|
||||
)
|
||||
from .qq_emoji_list import qq_face
|
||||
from .message_sending import message_send_instance
|
||||
from . import RealMessageType, MessageType
|
||||
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||
|
||||
import time
|
||||
import json
|
||||
|
|
@ -108,8 +109,8 @@ class MessageHandler:
|
|||
|
||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||
format_info: FormatInfo = FormatInfo(
|
||||
content_format=["text", "image", "emoji"],
|
||||
accept_format=["text", "image", "emoji", "reply", "voice", "command"],
|
||||
content_format=["text", "image", "emoji", "voice"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
) # 格式化信息
|
||||
if message_type == MessageType.private:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
|
|
@ -285,7 +286,13 @@ class MessageHandler:
|
|||
else:
|
||||
logger.warning("image处理失败")
|
||||
case RealMessageType.record:
|
||||
logger.warning("不支持语音解析")
|
||||
ret_seg = await self.handle_record_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.clear()
|
||||
seg_message.append(ret_seg)
|
||||
break # 使得消息只有record消息
|
||||
else:
|
||||
logger.warning("record处理失败或不支持")
|
||||
case RealMessageType.video:
|
||||
logger.warning("不支持视频解析")
|
||||
case RealMessageType.at:
|
||||
|
|
@ -405,6 +412,27 @@ class MessageHandler:
|
|||
else:
|
||||
return None
|
||||
|
||||
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理语音消息
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
Returns:
|
||||
seg_data: Seg: 处理后的消息段
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
file: str = message_data.get("file")
|
||||
try:
|
||||
record_detail = await get_record_detail(self.server_connection, file)
|
||||
audio_base64: str = record_detail.get("base64")
|
||||
except Exception as e:
|
||||
logger.error(f"语音消息处理失败: {str(e)}")
|
||||
return None
|
||||
if not audio_base64:
|
||||
logger.error("语音消息处理失败,未获取到音频数据")
|
||||
return None
|
||||
return Seg(type="voice", data=audio_base64)
|
||||
|
||||
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
||||
# sourcery skip: move-assign-in-block, use-named-expression
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class MetaEventHandler:
|
|||
if sub_type == MetaEventType.Lifecycle.connect:
|
||||
self_id = message.get("self_id")
|
||||
self.last_heart_beat = time.time()
|
||||
logger.info(f"Bot {self_id} 连接成功")
|
||||
logger.success(f"Bot {self_id} 连接成功")
|
||||
asyncio.create_task(self.check_heartbeat(self_id))
|
||||
elif event_type == MetaEventType.heartbeat:
|
||||
if message["status"].get("online") and message["status"].get("good"):
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from typing import Tuple, Optional
|
|||
from src.logger import logger
|
||||
from src.config import global_config
|
||||
from src.database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
from maim_message import UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.utils import (
|
||||
get_group_info,
|
||||
|
|
@ -151,7 +151,10 @@ class NoticeHandler:
|
|||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=None,
|
||||
format_info=FormatInfo(
|
||||
content_format=["text", "notify"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
),
|
||||
additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁
|
||||
)
|
||||
|
||||
|
|
@ -170,6 +173,7 @@ class NoticeHandler:
|
|||
async def handle_poke_notify(
|
||||
self, raw_message: dict, group_id: int, user_id: int
|
||||
) -> Tuple[Seg | None, UserInfo | None]:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
|
||||
self_info: dict = await get_self_info(self.server_connection)
|
||||
|
||||
if not self_info:
|
||||
|
|
|
|||
36
src/utils.py
36
src/utils.py
|
|
@ -11,7 +11,7 @@ from .logger import logger
|
|||
from .response_pool import get_response
|
||||
|
||||
from PIL import Image
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Optional
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
|
|
@ -219,6 +219,40 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
|
|||
return response.get("data")
|
||||
|
||||
|
||||
async def get_record_detail(
|
||||
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
|
||||
) -> dict | None:
|
||||
"""
|
||||
获取语音消息内容
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
file: 文件名
|
||||
file_id: 文件ID
|
||||
Returns:
|
||||
dict: 返回的语音消息详情
|
||||
"""
|
||||
logger.debug("获取语音消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_record",
|
||||
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取语音消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def read_ban_list(
|
||||
websocket: Server.ServerConnection,
|
||||
) -> Tuple[List[BanUser], List[BanUser]]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue