MaiBot/plugins/tts_voice_plugin/backends/gpt_sovits.py

327 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
GPT-SoVITS 后端实现
使用本地 GPT-SoVITS 服务进行语音合成
"""
import asyncio
from typing import Optional, Dict, Any, Tuple, ClassVar
from urllib.parse import quote, urlencode

from .base import TTSBackendBase, TTSResult
from ..utils.text import TTSTextUtils
from ..utils.file import TTSFileManager
from ..utils.session import TTSSessionManager
from ..config_keys import ConfigKeys
from src.common.logger import get_logger
# Module-level logger for this backend, obtained from the project's get_logger helper.
logger = get_logger("tts_gpt_sovits")
class GPTSoVITSBackend(TTSBackendBase):
    """
    GPT-SoVITS backend.

    Synthesizes speech through a locally hosted GPT-SoVITS HTTP service.
    Each configured style may pin its own GPT / SoVITS weight files; the
    backend asks the server to switch weights on demand before synthesizing.

    Both the newer API (``/tts``, ``/set_gpt_weights``, ``/set_sovits_weights``)
    and the legacy ``api.py`` API (``/`` and ``/set_model``) are supported:
    an HTTP 404 from a new-style endpoint triggers the legacy fallback.
    """

    backend_name = "gpt_sovits"
    backend_description = "本地GPT-SoVITS服务"
    support_private_chat = True
    default_audio_format = "mp3"

    # Class-level record of the weights last switched to successfully, shared
    # by all instances so repeated requests don't re-trigger a model switch.
    # NOTE(review): not keyed by server address and never invalidated if the
    # GPT-SoVITS service restarts — confirm this is acceptable.
    _current_gpt_weights: ClassVar[Optional[str]] = None
    _current_sovits_weights: ClassVar[Optional[str]] = None

    def get_default_voice(self) -> str:
        """Return the name of the default style."""
        return "default"

    @staticmethod
    def _build_url(server: str, path: str, params: Optional[Dict[str, str]] = None) -> str:
        """
        Join ``server`` and ``path`` and append a URL-encoded query string.

        Weight paths can contain characters that are significant in a URL
        (``&``, ``#``, ``+``, ``%``, ``=``, spaces, non-ASCII). Interpolating
        them raw would split or truncate the query, so every value is
        percent-encoded.
        """
        url = f"{server.rstrip('/')}{path}"
        if params:
            # quote (not quote_plus) so spaces become %20 and no raw '+' appears.
            url = f"{url}?{urlencode(params, quote_via=quote)}"
        return url

    async def _switch_model(
        self,
        server: str,
        gpt_weights: Optional[str],
        sovits_weights: Optional[str],
        timeout: int
    ) -> Tuple[bool, str]:
        """
        Switch the GPT-SoVITS server to the given weight files.

        Only weights that differ from the ones last switched to are sent. If
        the server lacks a per-model endpoint (HTTP 404), fall back to the
        legacy ``/set_model`` endpoint, which needs both paths at once.

        Args:
            server: Base URL of the GPT-SoVITS service.
            gpt_weights: GPT weights path, or None/empty to leave unchanged.
            sovits_weights: SoVITS weights path, or None/empty to leave unchanged.
            timeout: Timeout passed through to the session manager.

        Returns:
            ``(success, error_message)``; ``error_message`` is "" on success.
        """
        session_manager = await TTSSessionManager.get_instance()

        async def _set_model_v1() -> Tuple[bool, str]:
            # Legacy api.py only exposes /set_model, which switches both models together.
            if not gpt_weights or not sovits_weights:
                return False, "当前GPT-SoVITS服务不支持单独切换模型请同时配置GPT与SoVITS权重"
            set_model_url = self._build_url(
                server,
                "/set_model",
                {"gpt_model_path": gpt_weights, "sovits_model_path": sovits_weights},
            )
            logger.info(f"{self.log_prefix} 切换模型(兼容模式): {gpt_weights} | {sovits_weights}")
            try:
                async with session_manager.get(
                    set_model_url,
                    backend_name="gpt_sovits",
                    timeout=timeout
                ) as response:
                    if response.status == 200:
                        GPTSoVITSBackend._current_gpt_weights = gpt_weights
                        GPTSoVITSBackend._current_sovits_weights = sovits_weights
                        logger.info(f"{self.log_prefix} 模型切换成功(兼容模式)")
                        return True, ""
                    error_text = await response.text()
                    return False, f"模型切换失败: {error_text}"
            except Exception as e:
                return False, f"模型切换异常: {e}"

        # (label used in messages, endpoint, requested weights, class attribute caching them).
        # GPT is switched first, then SoVITS.
        steps = (
            ("GPT", "/set_gpt_weights", gpt_weights, "_current_gpt_weights"),
            ("SoVITS", "/set_sovits_weights", sovits_weights, "_current_sovits_weights"),
        )
        for label, endpoint, weights, cache_attr in steps:
            # Nothing to do when this model is not configured or already loaded.
            if not weights or weights == getattr(GPTSoVITSBackend, cache_attr):
                continue
            switch_url = self._build_url(server, endpoint, {"weights_path": weights})
            logger.info(f"{self.log_prefix} 切换{label}模型: {weights}")
            try:
                async with session_manager.get(
                    switch_url,
                    backend_name="gpt_sovits",
                    timeout=timeout
                ) as response:
                    if response.status == 404:
                        # Older services don't have the per-model endpoints.
                        return await _set_model_v1()
                    if response.status != 200:
                        error_text = await response.text()
                        return False, f"{label}模型切换失败: {error_text}"
                    setattr(GPTSoVITSBackend, cache_attr, weights)
                    logger.info(f"{self.log_prefix} {label}模型切换成功")
            except Exception as e:
                return False, f"{label}模型切换异常: {e}"
        return True, ""

    def _normalize_styles_config(self, styles_config: Any) -> Dict[str, Any]:
        """
        Normalize the styles configuration into a name -> settings dict.

        Two input formats are accepted:
        1. Legacy dict: ``{"default": {...}, "happy": {...}}`` (returned as is).
        2. List: ``[{"name": "default", ...}, {"name": "happy", ...}]`` —
           each entry's ``name`` becomes the key and is removed from the
           settings; entries that are not dicts or lack ``name`` are skipped,
           and a later duplicate name overrides an earlier one.

        Any other value yields an empty dict.
        """
        if isinstance(styles_config, dict):
            return styles_config
        if isinstance(styles_config, list):
            return {
                style["name"]: {k: v for k, v in style.items() if k != "name"}
                for style in styles_config
                if isinstance(style, dict) and "name" in style
            }
        return {}

    def validate_config(self) -> Tuple[bool, str]:
        """
        Check that a usable "default" style is configured.

        Returns:
            ``(ok, error_message)``; ``error_message`` is "" when valid.
        """
        styles = self._normalize_styles_config(
            self.get_config(ConfigKeys.GPT_SOVITS_STYLES, {})
        )
        if not styles:
            return False, "GPT-SoVITS未配置任何语音风格"
        # Distinguish "no styles at all" from "styles exist but no default".
        if "default" not in styles:
            return False, "GPT-SoVITS缺少default风格配置"
        default_style = styles.get("default", {})
        if not default_style.get("refer_wav") or not default_style.get("prompt_text"):
            return False, "GPT-SoVITS默认风格配置不完整需要refer_wav和prompt_text"
        return True, ""

    async def _response_to_result(self, response: Any, voice_style: str) -> TTSResult:
        """
        Convert a synthesis HTTP response into a TTSResult.

        On HTTP 200 the body is validated as audio and sent via ``send_audio``;
        any other status is logged and reported as a failure.
        """
        if response.status == 200:
            audio_data = await response.read()
            is_valid, error_msg = TTSFileManager.validate_audio_data(audio_data)
            if not is_valid:
                return TTSResult(False, f"GPT-SoVITS{error_msg}", backend_name=self.backend_name)
            return await self.send_audio(
                audio_data=audio_data,
                audio_format="wav",
                prefix="tts_gpt_sovits",
                voice_info=f"风格: {voice_style}"
            )
        error_info = await response.text()
        logger.error(f"{self.log_prefix} GPT-SoVITS API失败[{response.status}]: {error_info[:200]}")
        return TTSResult(
            False,
            f"GPT-SoVITS API调用失败: {response.status}",
            backend_name=self.backend_name
        )

    async def execute(
        self,
        text: str,
        voice: Optional[str] = None,
        **kwargs
    ) -> TTSResult:
        """
        Synthesize ``text`` with GPT-SoVITS and send the resulting audio.

        Args:
            text: Text to synthesize.
            voice: Style name; empty or unknown names fall back to "default".
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            TTSResult: the ``send_audio`` result on success, otherwise a failed
            result with an error message.
        """
        if not text or not text.strip():
            return TTSResult(False, "待合成的文本为空", backend_name=self.backend_name)

        server = self.get_config(ConfigKeys.GPT_SOVITS_SERVER, "http://127.0.0.1:9880")
        styles = self._normalize_styles_config(
            self.get_config(ConfigKeys.GPT_SOVITS_STYLES, {})
        )
        timeout = self.get_config(ConfigKeys.GENERAL_TIMEOUT, 60)

        voice_style = voice if voice and voice in styles else "default"
        if voice_style not in styles:
            return TTSResult(
                False,
                f"GPT-SoVITS风格 '{voice_style}' 未配置",
                backend_name=self.backend_name
            )

        style_config = styles[voice_style]
        refer_wav_path = style_config.get("refer_wav", "")
        prompt_text = style_config.get("prompt_text", "")
        prompt_language = style_config.get("prompt_language", "zh")
        gpt_weights = style_config.get("gpt_weights")
        sovits_weights = style_config.get("sovits_weights")
        if not refer_wav_path or not prompt_text:
            return TTSResult(
                False,
                f"GPT-SoVITS风格 '{voice_style}' 配置不完整",
                backend_name=self.backend_name
            )

        # Switch models first when the style pins specific weights.
        if gpt_weights or sovits_weights:
            switch_success, switch_error = await self._switch_model(
                server, gpt_weights, sovits_weights, timeout
            )
            if not switch_success:
                return TTSResult(False, switch_error, backend_name=self.backend_name)

        text_language = TTSTextUtils.detect_language(text)
        base_url = server.rstrip('/')

        # Request body for the newer API's /tts endpoint.
        data = {
            "text": text,
            "text_lang": text_language,
            "ref_audio_path": refer_wav_path,
            "prompt_text": prompt_text,
            "prompt_lang": prompt_language
        }
        # Same request with the legacy api.py field names, posted to "/".
        legacy_data = {
            "text": text,
            "text_language": text_language,
            "refer_wav_path": refer_wav_path,
            "prompt_text": prompt_text,
            "prompt_language": prompt_language,
        }

        logger.info(f"{self.log_prefix} GPT-SoVITS请求: text='{text[:50]}...', style={voice_style}")
        try:
            session_manager = await TTSSessionManager.get_instance()
            async with session_manager.post(
                f"{base_url}/tts",
                json=data,
                backend_name="gpt_sovits",
                timeout=timeout
            ) as response:
                if response.status != 404:
                    return await self._response_to_result(response, voice_style)
                # Legacy api.py has no /tts endpoint; fall back to the root path.
                logger.warning(f"{self.log_prefix} /tts 端点不存在,尝试兼容模式请求根路径")

            # Only reached when /tts returned 404.
            async with session_manager.post(
                f"{base_url}/",
                json=legacy_data,
                backend_name="gpt_sovits",
                timeout=timeout
            ) as response:
                return await self._response_to_result(response, voice_style)
        except asyncio.TimeoutError:
            return TTSResult(False, "GPT-SoVITS API调用超时", backend_name=self.backend_name)
        except Exception as e:
            logger.error(f"{self.log_prefix} GPT-SoVITS执行错误: {e}")
            return TTSResult(False, f"GPT-SoVITS执行错误: {e}", backend_name=self.backend_name)