mirror of https://github.com/Mai-with-u/MaiBot.git
240 lines
7.3 KiB
Python
240 lines
7.3 KiB
Python
"""
|
||
TTS后端抽象基类和注册表
|
||
"""
|
||
|
||
import asyncio
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import Dict, Type, Optional, Any, Callable, Tuple, Union
|
||
from src.common.logger import get_logger
|
||
from ..config_keys import ConfigKeys
|
||
|
||
logger = get_logger("tts_backend")
|
||
|
||
|
||
@dataclass
|
||
class TTSResult:
|
||
"""TTS执行结果"""
|
||
success: bool
|
||
message: str
|
||
audio_path: Optional[str] = None
|
||
backend_name: str = ""
|
||
|
||
def __iter__(self):
|
||
"""支持解包为 (success, message)"""
|
||
return iter((self.success, self.message))
|
||
|
||
|
||
class TTSBackendBase(ABC):
|
||
"""
|
||
TTS后端抽象基类
|
||
|
||
所有TTS后端必须继承此类并实现 execute 方法
|
||
"""
|
||
|
||
# 后端名称(子类必须覆盖)
|
||
backend_name: str = "base"
|
||
|
||
# 后端描述
|
||
backend_description: str = "TTS后端基类"
|
||
|
||
# 是否支持私聊
|
||
support_private_chat: bool = True
|
||
|
||
# 默认音频格式
|
||
default_audio_format: str = "mp3"
|
||
|
||
def __init__(self, config_getter: Callable[[str, Any], Any], log_prefix: str = ""):
|
||
"""
|
||
初始化后端
|
||
|
||
Args:
|
||
config_getter: 配置获取函数,签名为 get_config(key, default)
|
||
log_prefix: 日志前缀
|
||
"""
|
||
self.get_config = config_getter
|
||
self.log_prefix = log_prefix or f"[{self.backend_name}]"
|
||
self._send_custom = None
|
||
|
||
def set_send_custom(self, send_custom_func: Callable) -> None:
|
||
"""设置发送自定义消息的函数"""
|
||
self._send_custom = send_custom_func
|
||
|
||
async def send_audio(
|
||
self,
|
||
audio_data: bytes,
|
||
audio_format: str = "mp3",
|
||
prefix: str = "tts",
|
||
voice_info: str = ""
|
||
) -> TTSResult:
|
||
"""
|
||
统一的音频发送方法
|
||
|
||
Args:
|
||
audio_data: 音频二进制数据
|
||
audio_format: 音频格式(如mp3、wav)
|
||
prefix: 文件名前缀
|
||
voice_info: 音色信息(用于日志)
|
||
|
||
Returns:
|
||
TTSResult
|
||
"""
|
||
from ..utils.file import TTSFileManager
|
||
|
||
# 检查是否使用base64发送
|
||
use_base64 = self.get_config(ConfigKeys.GENERAL_USE_BASE64_AUDIO, False)
|
||
logger.debug(f"{self.log_prefix} 开始发送音频 (原始大小: {len(audio_data)}字节, 格式: {audio_format})")
|
||
|
||
if use_base64:
|
||
# 使用base64编码发送
|
||
base64_audio = TTSFileManager.audio_to_base64(audio_data)
|
||
if not base64_audio:
|
||
return TTSResult(False, "音频数据转base64失败", backend_name=self.backend_name)
|
||
|
||
logger.debug(f"{self.log_prefix} base64编码完成,准备通过send_custom发送")
|
||
if self._send_custom:
|
||
await self._send_custom(message_type="voice", content=base64_audio)
|
||
logger.info(f"{self.log_prefix} 语音已通过send_custom发送 (base64模式, 音频大小: {len(audio_data)}字节)")
|
||
else:
|
||
logger.warning(f"{self.log_prefix} send_custom未设置,无法发送语音")
|
||
return TTSResult(False, "send_custom回调未设置", backend_name=self.backend_name)
|
||
|
||
return TTSResult(
|
||
success=True,
|
||
message=f"成功发送{self.backend_name}语音{(' ('+voice_info+')') if voice_info else ''}, base64模式",
|
||
backend_name=self.backend_name
|
||
)
|
||
else:
|
||
# 使用文件路径发送
|
||
output_dir = self.get_config(ConfigKeys.GENERAL_AUDIO_OUTPUT_DIR, "")
|
||
audio_path = TTSFileManager.generate_temp_path(
|
||
prefix=prefix,
|
||
suffix=f".{audio_format}",
|
||
output_dir=output_dir
|
||
)
|
||
|
||
if not await TTSFileManager.write_audio_async(audio_path, audio_data):
|
||
return TTSResult(False, "保存音频文件失败", backend_name=self.backend_name)
|
||
|
||
logger.debug(f"{self.log_prefix} 音频文件已保存, 路径: {audio_path}")
|
||
# 发送语音
|
||
if self._send_custom:
|
||
await self._send_custom(message_type="voiceurl", content=audio_path)
|
||
logger.info(f"{self.log_prefix} 语音已通过send_custom发送 (文件路径模式, 路径: {audio_path})")
|
||
# 延迟清理临时文件
|
||
asyncio.create_task(TTSFileManager.cleanup_file_async(audio_path, delay=30))
|
||
else:
|
||
logger.warning(f"{self.log_prefix} send_custom未设置,无法发送语音")
|
||
return TTSResult(False, "send_custom回调未设置", backend_name=self.backend_name)
|
||
|
||
return TTSResult(
|
||
success=True,
|
||
message=f"成功发送{self.backend_name}语音{(' ('+voice_info+')') if voice_info else ''}",
|
||
audio_path=audio_path,
|
||
backend_name=self.backend_name
|
||
)
|
||
|
||
@abstractmethod
|
||
async def execute(
|
||
self,
|
||
text: str,
|
||
voice: Optional[str] = None,
|
||
**kwargs
|
||
) -> TTSResult:
|
||
"""
|
||
执行TTS转换
|
||
|
||
Args:
|
||
text: 待转换的文本
|
||
voice: 音色/风格
|
||
**kwargs: 其他参数(如emotion等)
|
||
|
||
Returns:
|
||
TTSResult 包含执行结果
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def validate_config(self) -> Tuple[bool, str]:
|
||
"""
|
||
验证后端配置是否完整
|
||
|
||
Returns:
|
||
(is_valid, error_message)
|
||
"""
|
||
return True, ""
|
||
|
||
def get_default_voice(self) -> str:
|
||
"""获取默认音色"""
|
||
return ""
|
||
|
||
def is_available(self) -> bool:
|
||
"""检查后端是否可用"""
|
||
is_valid, _ = self.validate_config()
|
||
return is_valid
|
||
|
||
|
||
class TTSBackendRegistry:
|
||
"""
|
||
TTS后端注册表
|
||
|
||
使用策略模式 + 工厂模式管理后端
|
||
"""
|
||
|
||
_backends: Dict[str, Type[TTSBackendBase]] = {}
|
||
|
||
@classmethod
|
||
def register(cls, name: str, backend_class: Type[TTSBackendBase]) -> None:
|
||
"""
|
||
注册后端
|
||
|
||
Args:
|
||
name: 后端名称
|
||
backend_class: 后端类
|
||
"""
|
||
cls._backends[name] = backend_class
|
||
logger.debug(f"注册TTS后端: {name}")
|
||
|
||
@classmethod
|
||
def unregister(cls, name: str) -> None:
|
||
"""注销后端"""
|
||
if name in cls._backends:
|
||
del cls._backends[name]
|
||
|
||
@classmethod
|
||
def get(cls, name: str) -> Optional[Type[TTSBackendBase]]:
|
||
"""获取后端类"""
|
||
return cls._backends.get(name)
|
||
|
||
@classmethod
|
||
def create(
|
||
cls,
|
||
name: str,
|
||
config_getter: Callable[[str, Any], Any],
|
||
log_prefix: str = ""
|
||
) -> Optional[TTSBackendBase]:
|
||
"""
|
||
创建后端实例
|
||
|
||
Args:
|
||
name: 后端名称
|
||
config_getter: 配置获取函数
|
||
log_prefix: 日志前缀
|
||
|
||
Returns:
|
||
后端实例或None
|
||
"""
|
||
backend_class = cls.get(name)
|
||
if backend_class:
|
||
return backend_class(config_getter, log_prefix)
|
||
return None
|
||
|
||
@classmethod
|
||
def list_backends(cls) -> list[str]:
|
||
"""列出所有已注册的后端名称"""
|
||
return list(cls._backends.keys())
|
||
|
||
@classmethod
|
||
def is_registered(cls, name: str) -> bool:
|
||
"""检查后端是否已注册"""
|
||
return name in cls._backends
|