MaiBot/utils/session.py

187 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
HTTP Session 管理器
提供连接池复用,避免每次请求创建新连接
"""
import asyncio
import aiohttp
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from src.common.logger import get_logger
logger = get_logger("tts_session_manager")
class TTSSessionManager:
"""
TTS HTTP Session 管理器
提供:
- 连接池复用
- 自动超时管理
- 优雅关闭
"""
_instance: Optional["TTSSessionManager"] = None
_lock = asyncio.Lock()
def __init__(self):
self._sessions: Dict[str, aiohttp.ClientSession] = {}
self._default_timeout = 60
@classmethod
async def get_instance(cls) -> "TTSSessionManager":
"""获取单例实例"""
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def get_session(
self,
backend_name: str = "default",
timeout: int = None
) -> aiohttp.ClientSession:
"""
获取或创建 HTTP Session
Args:
backend_name: 后端名称用于区分不同的session
timeout: 超时时间(秒)
Returns:
aiohttp.ClientSession 实例
"""
if backend_name not in self._sessions or self._sessions[backend_name].closed:
timeout_val = timeout or self._default_timeout
connector = aiohttp.TCPConnector(
limit=10, # 每个主机最大连接数
limit_per_host=5,
ttl_dns_cache=300, # DNS缓存5分钟
force_close=True, # 禁用连接复用修复GSV2P等API的兼容性问题
)
self._sessions[backend_name] = aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=timeout_val)
)
logger.debug(f"创建新的HTTP Session: {backend_name}")
return self._sessions[backend_name]
async def close_session(self, backend_name: str = None):
"""
关闭指定或所有 Session
Args:
backend_name: 后端名称为None时关闭所有
"""
if backend_name:
if backend_name in self._sessions:
await self._sessions[backend_name].close()
del self._sessions[backend_name]
logger.debug(f"关闭HTTP Session: {backend_name}")
else:
for name, session in self._sessions.items():
if not session.closed:
await session.close()
logger.debug(f"关闭HTTP Session: {name}")
self._sessions.clear()
@asynccontextmanager
async def post(
self,
url: str,
json: Dict[str, Any] = None,
headers: Dict[str, str] = None,
data: Any = None,
backend_name: str = "default",
timeout: int = None
):
"""
发送POST请求异步上下文管理器
Args:
url: 请求URL
json: JSON请求体
headers: 请求头
data: 表单数据
backend_name: 后端名称
timeout: 超时时间
Yields:
aiohttp.ClientResponse
Usage:
async with session_manager.post(url, json=data) as response:
...
"""
session = await self.get_session(backend_name, timeout)
# 如果指定了不同的超时时间,创建新的超时对象
req_timeout = None
if timeout:
req_timeout = aiohttp.ClientTimeout(total=timeout)
response = await session.post(
url,
json=json,
headers=headers,
data=data,
timeout=req_timeout
)
try:
yield response
finally:
response.release()
@asynccontextmanager
async def get(
self,
url: str,
headers: Dict[str, str] = None,
params: Dict[str, Any] = None,
backend_name: str = "default",
timeout: int = None
):
"""
发送GET请求异步上下文管理器
Args:
url: 请求URL
headers: 请求头
params: URL参数
backend_name: 后端名称
timeout: 超时时间
Yields:
aiohttp.ClientResponse
Usage:
async with session_manager.get(url) as response:
...
"""
session = await self.get_session(backend_name, timeout)
# 如果指定了不同的超时时间,创建新的超时对象
req_timeout = None
if timeout:
req_timeout = aiohttp.ClientTimeout(total=timeout)
response = await session.get(
url,
headers=headers,
params=params,
timeout=req_timeout
)
try:
yield response
finally:
response.release()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close_session()