语音消息识别
parent
92af300035
commit
60e9106a89
6
main.py
6
main.py
|
|
@ -20,11 +20,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||||
await send_handler.set_server_connection(server_connection)
|
await send_handler.set_server_connection(server_connection)
|
||||||
async for raw_message in server_connection:
|
async for raw_message in server_connection:
|
||||||
logger.debug(
|
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||||
f"{raw_message[:100]}..."
|
|
||||||
if (len(raw_message) > 100 and global_config.debug.level != "DEBUG")
|
|
||||||
else raw_message
|
|
||||||
)
|
|
||||||
decoded_raw_message: dict = json.loads(raw_message)
|
decoded_raw_message: dict = json.loads(raw_message)
|
||||||
post_type = decoded_raw_message.get("post_type")
|
post_type = decoded_raw_message.get("post_type")
|
||||||
if post_type in ["meta_event", "message", "notice"]:
|
if post_type in ["meta_event", "message", "notice"]:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "MaiBotNapcatAdapter"
|
name = "MaiBotNapcatAdapter"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
description = "A MaiBot adapter for Napcat"
|
description = "A MaiBot adapter for Napcat"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|
|
||||||
|
|
@ -82,3 +82,6 @@ class CommandType(Enum):
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command"]
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,13 @@ from src.utils import (
|
||||||
get_group_info,
|
get_group_info,
|
||||||
get_member_info,
|
get_member_info,
|
||||||
get_image_base64,
|
get_image_base64,
|
||||||
|
get_record_detail,
|
||||||
get_self_info,
|
get_self_info,
|
||||||
get_message_detail,
|
get_message_detail,
|
||||||
)
|
)
|
||||||
from .qq_emoji_list import qq_face
|
from .qq_emoji_list import qq_face
|
||||||
from .message_sending import message_send_instance
|
from .message_sending import message_send_instance
|
||||||
from . import RealMessageType, MessageType
|
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
@ -108,8 +109,8 @@ class MessageHandler:
|
||||||
|
|
||||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||||
format_info: FormatInfo = FormatInfo(
|
format_info: FormatInfo = FormatInfo(
|
||||||
content_format=["text", "image", "emoji"],
|
content_format=["text", "image", "emoji", "voice"],
|
||||||
accept_format=["text", "image", "emoji", "reply", "voice", "command"],
|
accept_format=ACCEPT_FORMAT,
|
||||||
) # 格式化信息
|
) # 格式化信息
|
||||||
if message_type == MessageType.private:
|
if message_type == MessageType.private:
|
||||||
sub_type = raw_message.get("sub_type")
|
sub_type = raw_message.get("sub_type")
|
||||||
|
|
@ -285,7 +286,13 @@ class MessageHandler:
|
||||||
else:
|
else:
|
||||||
logger.warning("image处理失败")
|
logger.warning("image处理失败")
|
||||||
case RealMessageType.record:
|
case RealMessageType.record:
|
||||||
logger.warning("不支持语音解析")
|
ret_seg = await self.handle_record_message(sub_message)
|
||||||
|
if ret_seg:
|
||||||
|
seg_message.clear()
|
||||||
|
seg_message.append(ret_seg)
|
||||||
|
break # 使得消息只有record消息
|
||||||
|
else:
|
||||||
|
logger.warning("record处理失败或不支持")
|
||||||
case RealMessageType.video:
|
case RealMessageType.video:
|
||||||
logger.warning("不支持视频解析")
|
logger.warning("不支持视频解析")
|
||||||
case RealMessageType.at:
|
case RealMessageType.at:
|
||||||
|
|
@ -405,6 +412,27 @@ class MessageHandler:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||||
|
"""
|
||||||
|
处理语音消息
|
||||||
|
Parameters:
|
||||||
|
raw_message: dict: 原始消息
|
||||||
|
Returns:
|
||||||
|
seg_data: Seg: 处理后的消息段
|
||||||
|
"""
|
||||||
|
message_data: dict = raw_message.get("data")
|
||||||
|
file: str = message_data.get("file")
|
||||||
|
try:
|
||||||
|
record_detail = await get_record_detail(self.server_connection, file)
|
||||||
|
audio_base64: str = record_detail.get("base64")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"语音消息处理失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
if not audio_base64:
|
||||||
|
logger.error("语音消息处理失败,未获取到音频数据")
|
||||||
|
return None
|
||||||
|
return Seg(type="voice", data=audio_base64)
|
||||||
|
|
||||||
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
||||||
# sourcery skip: move-assign-in-block, use-named-expression
|
# sourcery skip: move-assign-in-block, use-named-expression
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ class MetaEventHandler:
|
||||||
if sub_type == MetaEventType.Lifecycle.connect:
|
if sub_type == MetaEventType.Lifecycle.connect:
|
||||||
self_id = message.get("self_id")
|
self_id = message.get("self_id")
|
||||||
self.last_heart_beat = time.time()
|
self.last_heart_beat = time.time()
|
||||||
logger.info(f"Bot {self_id} 连接成功")
|
logger.success(f"Bot {self_id} 连接成功")
|
||||||
asyncio.create_task(self.check_heartbeat(self_id))
|
asyncio.create_task(self.check_heartbeat(self_id))
|
||||||
elif event_type == MetaEventType.heartbeat:
|
elif event_type == MetaEventType.heartbeat:
|
||||||
if message["status"].get("online") and message["status"].get("good"):
|
if message["status"].get("online") and message["status"].get("good"):
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@ from typing import Tuple, Optional
|
||||||
from src.logger import logger
|
from src.logger import logger
|
||||||
from src.config import global_config
|
from src.config import global_config
|
||||||
from src.database import BanUser, db_manager, is_identical
|
from src.database import BanUser, db_manager, is_identical
|
||||||
from . import NoticeType
|
from . import NoticeType, ACCEPT_FORMAT
|
||||||
from .message_sending import message_send_instance
|
from .message_sending import message_send_instance
|
||||||
from .message_handler import message_handler
|
from .message_handler import message_handler
|
||||||
from maim_message import UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
from src.utils import (
|
from src.utils import (
|
||||||
get_group_info,
|
get_group_info,
|
||||||
|
|
@ -151,7 +151,10 @@ class NoticeHandler:
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
template_info=None,
|
template_info=None,
|
||||||
format_info=None,
|
format_info=FormatInfo(
|
||||||
|
content_format=["text", "notify"],
|
||||||
|
accept_format=ACCEPT_FORMAT,
|
||||||
|
),
|
||||||
additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁
|
additional_config={"target_id": target_id}, # 在这里塞了一个target_id,方便mmc那边知道被戳的人是谁
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -170,6 +173,7 @@ class NoticeHandler:
|
||||||
async def handle_poke_notify(
|
async def handle_poke_notify(
|
||||||
self, raw_message: dict, group_id: int, user_id: int
|
self, raw_message: dict, group_id: int, user_id: int
|
||||||
) -> Tuple[Seg | None, UserInfo | None]:
|
) -> Tuple[Seg | None, UserInfo | None]:
|
||||||
|
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
|
||||||
self_info: dict = await get_self_info(self.server_connection)
|
self_info: dict = await get_self_info(self.server_connection)
|
||||||
|
|
||||||
if not self_info:
|
if not self_info:
|
||||||
|
|
|
||||||
36
src/utils.py
36
src/utils.py
|
|
@ -11,7 +11,7 @@ from .logger import logger
|
||||||
from .response_pool import get_response
|
from .response_pool import get_response
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Union, List, Tuple
|
from typing import Union, List, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
class SSLAdapter(urllib3.PoolManager):
|
class SSLAdapter(urllib3.PoolManager):
|
||||||
|
|
@ -219,6 +219,40 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
|
||||||
return response.get("data")
|
return response.get("data")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_record_detail(
|
||||||
|
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
|
||||||
|
) -> dict | None:
|
||||||
|
"""
|
||||||
|
获取语音消息内容
|
||||||
|
Parameters:
|
||||||
|
websocket: WebSocket连接对象
|
||||||
|
file: 文件名
|
||||||
|
file_id: 文件ID
|
||||||
|
Returns:
|
||||||
|
dict: 返回的语音消息详情
|
||||||
|
"""
|
||||||
|
logger.debug("获取语音消息详情中")
|
||||||
|
request_uuid = str(uuid.uuid4())
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"action": "get_record",
|
||||||
|
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||||
|
"echo": request_uuid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await websocket.send(payload)
|
||||||
|
response: dict = await get_response(request_uuid)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取语音消息详情失败: {e}")
|
||||||
|
return None
|
||||||
|
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
|
||||||
|
return response.get("data")
|
||||||
|
|
||||||
|
|
||||||
async def read_ban_list(
|
async def read_ban_list(
|
||||||
websocket: Server.ServerConnection,
|
websocket: Server.ServerConnection,
|
||||||
) -> Tuple[List[BanUser], List[BanUser]]:
|
) -> Tuple[List[BanUser], List[BanUser]]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue