fix typing of utils_model.py

pull/1240/head
UnCLAS-Prommer 2025-09-17 15:59:02 +08:00
parent 91e716a24c
commit 1260a11b78
No known key found for this signature in database
2 changed files with 40 additions and 18 deletions

View File

@ -14,6 +14,7 @@ from src.plugin_system import (
MaiMessages,
ToolParamType,
ReplyContentType,
emoji_api,
)
from src.config.config import global_config
@ -181,7 +182,26 @@ class ForwardMessages(BaseEventHandler):
raise ValueError("转发消息失败")
self.messages = []
return True, True, None, None, None
class RandomEmojis(BaseCommand):
    """Command matched by ``/random_emojis`` that sends several random emojis at once."""

    command_name = "random_emojis"
    command_description = "发送多张随机表情包"
    command_pattern = r"^/random_emojis$"

    async def execute(self):
        """Fetch random emojis and forward them to the user.

        Returns:
            A 3-tuple. When ``emoji_api.get_random(5)`` yields nothing the
            tuple is ``(False, "未找到表情包", False)``; otherwise it is
            whatever ``forward_images`` returns.
        """
        picked = await emoji_api.get_random(5)
        if picked:
            # Only the first element of each entry is forwarded.
            # NOTE(review): presumably the base64 image payload - confirm
            # against the emoji_api.get_random return shape.
            return await self.forward_images([entry[0] for entry in picked])
        return False, "未找到表情包", False

    async def forward_images(self, images: List[str]):
        """Send several images to the user as one merged-forward message.

        Args:
            images: One string per image; each is wrapped in its own
                forward node as ``ReplyContentType.IMAGE`` content.

        Returns:
            ``(True, "已发送随机表情包", True)`` when ``send_forward`` reports
            success, otherwise ``(False, "发送随机表情包失败", False)``.
        """
        # One node per image.
        # NOTE(review): the tuple looks like (sender id, display name,
        # content list) - verify against send_forward's signature.
        nodes = [("0", "神秘用户", [(ReplyContentType.IMAGE, image)]) for image in images]
        sent = await self.send_forward(nodes)
        if sent:
            return True, "已发送随机表情包", True
        return False, "发送随机表情包失败", False
# ===== 插件注册 =====
@ -225,6 +245,7 @@ class HelloWorldPlugin(BasePlugin):
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
(ForwardMessages.get_handler_info(), ForwardMessages),
(RandomEmojis.get_command_info(), RandomEmojis),
]

View File

@ -4,7 +4,7 @@ import time
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback
from src.common.logger import get_logger
@ -82,9 +82,7 @@ class LLMRequest:
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(
image_base64=image_base64,
image_format=image_format,
support_formats=client.get_support_image_formats()
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
)
return [message_builder.build()]
@ -145,7 +143,7 @@ class LLMRequest:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容推理内容模型名称工具调用列表
"""
start_time = time.time()
def message_factory(client: BaseClient) -> List[Message]:
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
@ -177,7 +175,7 @@ class LLMRequest:
endpoint="/chat/completions",
time_cost=time.time() - start_time,
)
return content, (reasoning_content, model_info.name, tool_calls)
return content or "", (reasoning_content, model_info.name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""
@ -206,7 +204,7 @@ class LLMRequest:
raise RuntimeError("获取embedding失败")
return embedding, model_info.name
def _select_model(self, exclude_models: set = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
def _select_model(self, exclude_models: Optional[Set[str]] = None) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
"""
@ -224,7 +222,7 @@ class LLMRequest:
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
force_new_client = (self.request_type == "embedding")
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
@ -246,13 +244,13 @@ class LLMRequest:
max_tokens: Optional[int],
embedding_input: str | None,
audio_base64: str | None,
compressed_messages: Optional[List[Message]] = None,
) -> APIResponse:
"""
在单个模型上执行请求包含针对临时错误的重试逻辑
如果成功返回APIResponse如果失败重试耗尽或硬错误则抛出ModelAttemptFailed异常
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
while retry_remain > 0:
try:
@ -299,7 +297,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
continue
@ -315,8 +315,8 @@ class LLMRequest:
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
except Exception as e:
logger.error(traceback.format_exc())
logger.error(traceback.format_exc())
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
@ -338,12 +338,11 @@ class LLMRequest:
"""
调度器函数负责模型选择故障切换
"""
failed_models_this_request = set()
failed_models_this_request: Set[str] = set()
max_attempts = len(self.model_for_task.model_list)
last_exception: Optional[Exception] = None
compressed_messages: Optional[List[Message]] = None
for _attempt in range(max_attempts):
for _ in range(max_attempts):
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
message_list = []
@ -352,7 +351,10 @@ class LLMRequest:
try:
response = await self._attempt_request_on_model(
model_info, api_provider, client, request_type,
model_info,
api_provider,
client,
request_type,
message_list=message_list,
tool_options=tool_options,
response_format=response_format,
@ -362,7 +364,6 @@ class LLMRequest:
max_tokens=max_tokens,
embedding_input=embedding_input,
audio_base64=audio_base64,
compressed_messages=compressed_messages,
)
return response, model_info
@ -430,4 +431,4 @@ class LLMRequest:
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""
return content, reasoning
return content, reasoning