MaiBot/plugins/tts_voice_plugin/backends/base.py

240 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
TTS后端抽象基类和注册表
"""
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Type, Optional, Any, Callable, Tuple, Union
from src.common.logger import get_logger
from ..config_keys import ConfigKeys
logger = get_logger("tts_backend")
@dataclass
class TTSResult:
"""TTS执行结果"""
success: bool
message: str
audio_path: Optional[str] = None
backend_name: str = ""
def __iter__(self):
"""支持解包为 (success, message)"""
return iter((self.success, self.message))
class TTSBackendBase(ABC):
"""
TTS后端抽象基类
所有TTS后端必须继承此类并实现 execute 方法
"""
# 后端名称(子类必须覆盖)
backend_name: str = "base"
# 后端描述
backend_description: str = "TTS后端基类"
# 是否支持私聊
support_private_chat: bool = True
# 默认音频格式
default_audio_format: str = "mp3"
def __init__(self, config_getter: Callable[[str, Any], Any], log_prefix: str = ""):
"""
初始化后端
Args:
config_getter: 配置获取函数,签名为 get_config(key, default)
log_prefix: 日志前缀
"""
self.get_config = config_getter
self.log_prefix = log_prefix or f"[{self.backend_name}]"
self._send_custom = None
def set_send_custom(self, send_custom_func: Callable) -> None:
"""设置发送自定义消息的函数"""
self._send_custom = send_custom_func
async def send_audio(
self,
audio_data: bytes,
audio_format: str = "mp3",
prefix: str = "tts",
voice_info: str = ""
) -> TTSResult:
"""
统一的音频发送方法
Args:
audio_data: 音频二进制数据
audio_format: 音频格式如mp3、wav
prefix: 文件名前缀
voice_info: 音色信息(用于日志)
Returns:
TTSResult
"""
from ..utils.file import TTSFileManager
# 检查是否使用base64发送
use_base64 = self.get_config(ConfigKeys.GENERAL_USE_BASE64_AUDIO, False)
logger.debug(f"{self.log_prefix} 开始发送音频 (原始大小: {len(audio_data)}字节, 格式: {audio_format})")
if use_base64:
# 使用base64编码发送
base64_audio = TTSFileManager.audio_to_base64(audio_data)
if not base64_audio:
return TTSResult(False, "音频数据转base64失败", backend_name=self.backend_name)
logger.debug(f"{self.log_prefix} base64编码完成准备通过send_custom发送")
if self._send_custom:
await self._send_custom(message_type="voice", content=base64_audio)
logger.info(f"{self.log_prefix} 语音已通过send_custom发送 (base64模式, 音频大小: {len(audio_data)}字节)")
else:
logger.warning(f"{self.log_prefix} send_custom未设置无法发送语音")
return TTSResult(False, "send_custom回调未设置", backend_name=self.backend_name)
return TTSResult(
success=True,
message=f"成功发送{self.backend_name}语音{(' ('+voice_info+')') if voice_info else ''}, base64模式",
backend_name=self.backend_name
)
else:
# 使用文件路径发送
output_dir = self.get_config(ConfigKeys.GENERAL_AUDIO_OUTPUT_DIR, "")
audio_path = TTSFileManager.generate_temp_path(
prefix=prefix,
suffix=f".{audio_format}",
output_dir=output_dir
)
if not await TTSFileManager.write_audio_async(audio_path, audio_data):
return TTSResult(False, "保存音频文件失败", backend_name=self.backend_name)
logger.debug(f"{self.log_prefix} 音频文件已保存, 路径: {audio_path}")
# 发送语音
if self._send_custom:
await self._send_custom(message_type="voiceurl", content=audio_path)
logger.info(f"{self.log_prefix} 语音已通过send_custom发送 (文件路径模式, 路径: {audio_path})")
# 延迟清理临时文件
asyncio.create_task(TTSFileManager.cleanup_file_async(audio_path, delay=30))
else:
logger.warning(f"{self.log_prefix} send_custom未设置无法发送语音")
return TTSResult(False, "send_custom回调未设置", backend_name=self.backend_name)
return TTSResult(
success=True,
message=f"成功发送{self.backend_name}语音{(' ('+voice_info+')') if voice_info else ''}",
audio_path=audio_path,
backend_name=self.backend_name
)
@abstractmethod
async def execute(
self,
text: str,
voice: Optional[str] = None,
**kwargs
) -> TTSResult:
"""
执行TTS转换
Args:
text: 待转换的文本
voice: 音色/风格
**kwargs: 其他参数如emotion等
Returns:
TTSResult 包含执行结果
"""
raise NotImplementedError
def validate_config(self) -> Tuple[bool, str]:
"""
验证后端配置是否完整
Returns:
(is_valid, error_message)
"""
return True, ""
def get_default_voice(self) -> str:
"""获取默认音色"""
return ""
def is_available(self) -> bool:
"""检查后端是否可用"""
is_valid, _ = self.validate_config()
return is_valid
class TTSBackendRegistry:
"""
TTS后端注册表
使用策略模式 + 工厂模式管理后端
"""
_backends: Dict[str, Type[TTSBackendBase]] = {}
@classmethod
def register(cls, name: str, backend_class: Type[TTSBackendBase]) -> None:
"""
注册后端
Args:
name: 后端名称
backend_class: 后端类
"""
cls._backends[name] = backend_class
logger.debug(f"注册TTS后端: {name}")
@classmethod
def unregister(cls, name: str) -> None:
"""注销后端"""
if name in cls._backends:
del cls._backends[name]
@classmethod
def get(cls, name: str) -> Optional[Type[TTSBackendBase]]:
"""获取后端类"""
return cls._backends.get(name)
@classmethod
def create(
cls,
name: str,
config_getter: Callable[[str, Any], Any],
log_prefix: str = ""
) -> Optional[TTSBackendBase]:
"""
创建后端实例
Args:
name: 后端名称
config_getter: 配置获取函数
log_prefix: 日志前缀
Returns:
后端实例或None
"""
backend_class = cls.get(name)
if backend_class:
return backend_class(config_getter, log_prefix)
return None
@classmethod
def list_backends(cls) -> list[str]:
"""列出所有已注册的后端名称"""
return list(cls._backends.keys())
@classmethod
def is_registered(cls, name: str) -> bool:
"""检查后端是否已注册"""
return name in cls._backends