mirror of https://github.com/Mai-with-u/MaiBot.git
🤖 自动格式化代码 [skip ci]
parent
662f92219e
commit
8b837e67bc
|
|
@ -153,7 +153,9 @@ class BotConfig:
|
|||
"用一句话或几句话描述人格的一些侧面",
|
||||
]
|
||||
)
|
||||
personality_detail_level: int = 0 # 人设消息注入 prompt 详细等级 (0: 采用默认配置, 1: 核心/随机细节, 2: 核心+随机侧面/全部细节, 3: 全部)
|
||||
personality_detail_level: int = (
|
||||
0 # 人设消息注入 prompt 详细等级 (0: 采用默认配置, 1: 核心/随机细节, 2: 核心+随机侧面/全部细节, 3: 全部)
|
||||
)
|
||||
# identity
|
||||
identity_detail: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
|
|
@ -179,7 +181,7 @@ class BotConfig:
|
|||
|
||||
base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
|
||||
base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
|
||||
allow_remove_duplicates: bool = True # 是否开启心流去重(如果发现心流截断问题严重可尝试关闭)
|
||||
allow_remove_duplicates: bool = True # 是否开启心流去重(如果发现心流截断问题严重可尝试关闭)
|
||||
|
||||
observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
|
||||
|
||||
|
|
@ -244,7 +246,7 @@ class BotConfig:
|
|||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
long_message_auto_truncate: bool = True # HFC 模式过长消息自动截断,防止他人 prompt 恶意注入,减少token消耗,但可能损失图片/长文信息,按需选择状态(默认开启)
|
||||
long_message_auto_truncate: bool = True # HFC 模式过长消息自动截断,防止他人 prompt 恶意注入,减少token消耗,但可能损失图片/长文信息,按需选择状态(默认开启)
|
||||
|
||||
# mood
|
||||
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
|
||||
|
|
@ -289,7 +291,6 @@ class BotConfig:
|
|||
min_cooldown: int = 7200 # 最短冷却时间,2小时 (7200秒)
|
||||
max_cooldown: int = 18000 # 最长冷却时间,5小时 (18000秒)
|
||||
|
||||
|
||||
# Group Nickname
|
||||
enable_nickname_mapping: bool = False # 绰号映射功能总开关
|
||||
max_nicknames_in_prompt: int = 10 # Prompt 中最多注入的绰号数量
|
||||
|
|
@ -388,7 +389,9 @@ class BotConfig:
|
|||
config.personality_core = personality_config.get("personality_core", config.personality_core)
|
||||
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.6.1.2"):
|
||||
config.personality_detail_level = personality_config.get("personality_detail_level", config.personality_sides)
|
||||
config.personality_detail_level = personality_config.get(
|
||||
"personality_detail_level", config.personality_sides
|
||||
)
|
||||
|
||||
def identity(parent: dict):
|
||||
identity_config = parent["identity"]
|
||||
|
|
@ -474,7 +477,9 @@ class BotConfig:
|
|||
for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
|
||||
config.ban_msgs_regex.add(re.compile(r))
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.6.1.2"):
|
||||
config.allow_remove_duplicates = chat_config.get("allow_remove_duplicates", config.allow_remove_duplicates)
|
||||
config.allow_remove_duplicates = chat_config.get(
|
||||
"allow_remove_duplicates", config.allow_remove_duplicates
|
||||
)
|
||||
|
||||
def normal_chat(parent: dict):
|
||||
normal_chat_config = parent["normal_chat"]
|
||||
|
|
@ -719,7 +724,9 @@ class BotConfig:
|
|||
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
|
||||
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.6.1.5"):
|
||||
config.api_polling_max_retries = experimental_config.get("api_polling_max_retries", config.api_polling_max_retries)
|
||||
config.api_polling_max_retries = experimental_config.get(
|
||||
"api_polling_max_retries", config.api_polling_max_retries
|
||||
)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.6.2"):
|
||||
config.enable_pfc_reply_checker = experimental_config.get(
|
||||
"enable_pfc_reply_checker", config.enable_pfc_reply_checker
|
||||
|
|
|
|||
|
|
@ -254,7 +254,10 @@ class SubMind:
|
|||
|
||||
# 思考指导选项和权重
|
||||
hf_options = [
|
||||
("可以参考之前的想法,在原来想法的基础上继续思考,但是也要注意话题的推进,不要在一个话题上停留太久,除非你觉得真的有必要", 0.3),
|
||||
(
|
||||
"可以参考之前的想法,在原来想法的基础上继续思考,但是也要注意话题的推进,不要在一个话题上停留太久,除非你觉得真的有必要",
|
||||
0.3,
|
||||
),
|
||||
("可以参考之前的想法,在原来的想法上尝试新的话题", 0.3),
|
||||
("不要太深入,注意话题的推进,不要在一个话题上停留太久,除非你觉得真的有必要", 0.2),
|
||||
("进行深入思考,但是注意话题的推进,不要在一个话题上停留太久,除非你觉得真的有必要", 0.2),
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ class Individuality:
|
|||
level = 2
|
||||
elif global_config.personality_detail_level == 3:
|
||||
level = 3
|
||||
else: # level = 0
|
||||
else: # level = 0
|
||||
pass
|
||||
|
||||
# 调用新的独立方法
|
||||
|
|
|
|||
|
|
@ -926,7 +926,7 @@ class HeartFChatting:
|
|||
response_content = json_match.group(0)
|
||||
else:
|
||||
logger.warning(f"LLM 响应似乎不包含有效的 JSON 对象。响应: {response_content}")
|
||||
|
||||
|
||||
cleaned_content = response_content
|
||||
if not cleaned_content:
|
||||
raise json.JSONDecodeError("Cleaned content is empty", cleaned_content, 0)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import json
|
||||
import random # 添加 random 模块导入
|
||||
import random # 添加 random 模块导入
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Union, Dict, Any, Set # 引入 Set
|
||||
from typing import Tuple, Union, Dict, Any, Set # 引入 Set
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client import ClientResponse
|
||||
|
|
@ -36,6 +36,7 @@ logger = get_module_logger("model_utils")
|
|||
|
||||
class PayLoadTooLargeError(Exception):
|
||||
"""自定义异常类,用于处理请求体过大错误"""
|
||||
|
||||
# (代码不变)
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
|
@ -47,6 +48,7 @@ class PayLoadTooLargeError(Exception):
|
|||
|
||||
class RequestAbortException(Exception):
|
||||
"""自定义异常类,用于处理请求中断异常"""
|
||||
|
||||
# (代码不变)
|
||||
def __init__(self, message: str, response: ClientResponse):
|
||||
super().__init__(message)
|
||||
|
|
@ -59,11 +61,12 @@ class RequestAbortException(Exception):
|
|||
|
||||
class PermissionDeniedException(Exception):
|
||||
"""自定义异常类,用于处理访问拒绝的异常"""
|
||||
|
||||
# (代码不变)
|
||||
def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符
|
||||
def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.key_identifier = key_identifier # 存储导致 403 的 key
|
||||
self.key_identifier = key_identifier # 存储导致 403 的 key
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
|
@ -72,6 +75,7 @@ class PermissionDeniedException(Exception):
|
|||
# 新增:用于内部标记需要切换 Key 的异常
|
||||
class _SwitchKeyException(Exception):
|
||||
"""内部异常,用于标记需要切换Key并且跳过标准等待时间."""
|
||||
|
||||
# (代码不变)
|
||||
pass
|
||||
|
||||
|
|
@ -80,9 +84,9 @@ class _SwitchKeyException(Exception):
|
|||
error_code_mapping = {
|
||||
# (代码不变)
|
||||
400: "参数不正确",
|
||||
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效
|
||||
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效
|
||||
402: "账号余额不足",
|
||||
403: "需要实名,或余额不足,或Key无权限", # 扩展 403 的含义
|
||||
403: "需要实名,或余额不足,或Key无权限", # 扩展 403 的含义
|
||||
404: "Not Found",
|
||||
429: "请求过于频繁,请稍后再试",
|
||||
500: "服务器内部故障",
|
||||
|
|
@ -98,15 +102,16 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
|
|||
is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
|
||||
safe_payload = json.loads(json.dumps(payload)) if payload else {}
|
||||
|
||||
if (
|
||||
image_base64
|
||||
and safe_payload
|
||||
and isinstance(safe_payload, dict)
|
||||
):
|
||||
if image_base64 and safe_payload and isinstance(safe_payload, dict):
|
||||
if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
|
||||
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
|
||||
content = safe_payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and isinstance(content[1], dict) and "image_url" in content[1]:
|
||||
if (
|
||||
isinstance(content, list)
|
||||
and len(content) > 1
|
||||
and isinstance(content[1], dict)
|
||||
and "image_url" in content[1]
|
||||
):
|
||||
safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
|
|
@ -116,7 +121,9 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
|
|||
parts = safe_payload["contents"][0]["parts"]
|
||||
for i, part in enumerate(parts):
|
||||
if isinstance(part, dict) and "inlineData" in part:
|
||||
safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = (
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
break
|
||||
|
||||
return safe_payload
|
||||
|
|
@ -125,9 +132,19 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
|
|||
class LLMRequest:
|
||||
# (代码不变)
|
||||
MODELS_NEEDING_TRANSFORMATION = [
|
||||
"o1", "o1-2024-12-17", "o1-mini", "o1-mini-2024-09-12", "o1-preview",
|
||||
"o1-preview-2024-09-12", "o1-pro", "o1-pro-2025-03-19", "o3", "o3-2025-04-16",
|
||||
"o3-mini", "o3-mini-2025-01-31o4-mini", "o4-mini-2025-04-16",
|
||||
"o1",
|
||||
"o1-2024-12-17",
|
||||
"o1-mini",
|
||||
"o1-mini-2024-09-12",
|
||||
"o1-preview",
|
||||
"o1-preview-2024-09-12",
|
||||
"o1-pro",
|
||||
"o1-pro-2025-03-19",
|
||||
"o3",
|
||||
"o3-2025-04-16",
|
||||
"o3-mini",
|
||||
"o3-mini-2025-01-31o4-mini",
|
||||
"o4-mini-2025-04-16",
|
||||
]
|
||||
_abandoned_keys_runtime: Set[str] = set()
|
||||
|
||||
|
|
@ -171,7 +188,9 @@ class LLMRequest:
|
|||
elif isinstance(raw_api_key_config, str) and raw_api_key_config:
|
||||
parsed_keys = [raw_api_key_config]
|
||||
else:
|
||||
raise ValueError(f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}") from e
|
||||
raise ValueError(
|
||||
f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}"
|
||||
) from e
|
||||
|
||||
if not parsed_keys:
|
||||
raise ValueError(f"No valid API keys found for {self.model_key_name}.")
|
||||
|
|
@ -187,14 +206,20 @@ class LLMRequest:
|
|||
abandoned_keys_set.update(str(key) for key in loaded_abandoned if key)
|
||||
elif isinstance(loaded_abandoned, str) and loaded_abandoned:
|
||||
abandoned_keys_set.add(loaded_abandoned)
|
||||
logger.info(f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。")
|
||||
logger.info(
|
||||
f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。"
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
if isinstance(raw_abandoned_keys, list):
|
||||
abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key)
|
||||
logger.info(f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。")
|
||||
logger.info(
|
||||
f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。"
|
||||
)
|
||||
elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys:
|
||||
abandoned_keys_set.add(raw_abandoned_keys)
|
||||
logger.info(f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。")
|
||||
logger.info(
|
||||
f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}")
|
||||
|
||||
|
|
@ -203,18 +228,23 @@ class LLMRequest:
|
|||
|
||||
if not active_keys:
|
||||
logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。")
|
||||
raise ValueError(f"No active API keys available for {self.model_key_name} after filtering abandoned keys.")
|
||||
raise ValueError(
|
||||
f"No active API keys available for {self.model_key_name} after filtering abandoned keys."
|
||||
)
|
||||
|
||||
if is_list_config and len(active_keys) > 1:
|
||||
self._api_key_config = active_keys
|
||||
logger.info(f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。")
|
||||
logger.info(
|
||||
f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
|
||||
)
|
||||
elif active_keys:
|
||||
self._api_key_config = active_keys[0]
|
||||
logger.info(f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。")
|
||||
logger.info(
|
||||
f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
|
||||
|
||||
|
||||
# 加载代理配置 (代码不变)
|
||||
self.proxy_url = None
|
||||
self.proxy_models_set = set()
|
||||
|
|
@ -230,11 +260,15 @@ class LLMRequest:
|
|||
|
||||
if proxy_models_str:
|
||||
try:
|
||||
cleaned_str = proxy_models_str.strip('\'"')
|
||||
self.proxy_models_set = {model_name.strip() for model_name in cleaned_str.split(',') if model_name.strip()}
|
||||
cleaned_str = proxy_models_str.strip("'\"")
|
||||
self.proxy_models_set = {
|
||||
model_name.strip() for model_name in cleaned_str.split(",") if model_name.strip()
|
||||
}
|
||||
logger.debug(f"以下模型将使用代理: {self.proxy_models_set}")
|
||||
except Exception as e:
|
||||
logger.error(f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。")
|
||||
logger.error(
|
||||
f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。"
|
||||
)
|
||||
self.proxy_models_set = set()
|
||||
except ValueError:
|
||||
logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。")
|
||||
|
|
@ -247,7 +281,6 @@ class LLMRequest:
|
|||
else:
|
||||
logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。")
|
||||
|
||||
|
||||
except KeyError as e:
|
||||
# (代码不变)
|
||||
missing_key = str(e).strip("'")
|
||||
|
|
@ -340,7 +373,7 @@ class LLMRequest:
|
|||
image_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""配置请求参数,合并实例参数和调用时参数"""
|
||||
default_retry = {
|
||||
|
|
@ -353,14 +386,14 @@ class LLMRequest:
|
|||
|
||||
_actual_endpoint = endpoint
|
||||
if self.is_gemini:
|
||||
action = endpoint.lstrip('/')
|
||||
action = endpoint.lstrip("/")
|
||||
api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}"
|
||||
stream_mode = False
|
||||
else:
|
||||
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||
stream_mode = self.stream
|
||||
|
||||
call_params = {k: v for k, v in kwargs.items() if k != 'request_type'}
|
||||
call_params = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||||
merged_params = {**self.params, **call_params}
|
||||
|
||||
if payload is None:
|
||||
|
|
@ -368,11 +401,9 @@ class LLMRequest:
|
|||
else:
|
||||
logger.debug("使用外部提供的 payload,忽略单次调用参数合并。")
|
||||
|
||||
|
||||
if not self.is_gemini and stream_mode:
|
||||
payload["stream"] = merged_params.get("stream", stream_mode)
|
||||
|
||||
|
||||
return {
|
||||
"policy": policy,
|
||||
"payload": payload,
|
||||
|
|
@ -394,11 +425,11 @@ class LLMRequest:
|
|||
response_handler: callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""统一请求执行入口, 支持列表 key 切换、代理和单次调用参数覆盖"""
|
||||
final_request_type = request_type or kwargs.get('request_type') or self.request_type
|
||||
api_kwargs = {k: v for k, v in kwargs.items() if k != 'request_type'}
|
||||
final_request_type = request_type or kwargs.get("request_type") or self.request_type
|
||||
api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||||
|
||||
request_content = await self._prepare_request(
|
||||
endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs
|
||||
|
|
@ -436,7 +467,9 @@ class LLMRequest:
|
|||
random.shuffle(available_keys_pool)
|
||||
key_switch_limit_429 = min(key_switch_limit_429, len(available_keys_pool))
|
||||
key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool))
|
||||
logger.info(f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。")
|
||||
logger.info(
|
||||
f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。"
|
||||
)
|
||||
elif isinstance(self._api_key_config, str):
|
||||
available_keys_pool = [self._api_key_config]
|
||||
key_switch_limit_429 = 1
|
||||
|
|
@ -451,9 +484,22 @@ class LLMRequest:
|
|||
if available_keys_pool:
|
||||
current_key = available_keys_pool.pop(0)
|
||||
elif current_key:
|
||||
logger.debug(f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)")
|
||||
logger.debug(
|
||||
f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)"
|
||||
)
|
||||
else:
|
||||
if not self._api_key_config or all(k in LLMRequest._abandoned_keys_runtime for k in self._api_key_config if isinstance(self._api_key_config, list)) or (isinstance(self._api_key_config, str) and self._api_key_config in LLMRequest._abandoned_keys_runtime):
|
||||
if (
|
||||
not self._api_key_config
|
||||
or all(
|
||||
k in LLMRequest._abandoned_keys_runtime
|
||||
for k in self._api_key_config
|
||||
if isinstance(self._api_key_config, list)
|
||||
)
|
||||
or (
|
||||
isinstance(self._api_key_config, str)
|
||||
and self._api_key_config in LLMRequest._abandoned_keys_runtime
|
||||
)
|
||||
):
|
||||
final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
|
||||
logger.critical(final_error_msg)
|
||||
raise PermissionDeniedException(final_error_msg)
|
||||
|
|
@ -468,11 +514,7 @@ class LLMRequest:
|
|||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
post_kwargs = {
|
||||
"headers": headers,
|
||||
"json": actual_payload,
|
||||
"timeout": 60
|
||||
}
|
||||
post_kwargs = {"headers": headers, "json": actual_payload, "timeout": 60}
|
||||
if use_proxy:
|
||||
post_kwargs["proxy"] = current_proxy_url
|
||||
|
||||
|
|
@ -480,10 +522,14 @@ class LLMRequest:
|
|||
if response.status == 429 and is_key_list:
|
||||
logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。")
|
||||
response_text = await response.text()
|
||||
logger.debug(f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}")
|
||||
logger.debug(
|
||||
f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}"
|
||||
)
|
||||
if current_key not in keys_failed_429:
|
||||
keys_failed_429.add(current_key)
|
||||
logger.info(f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)")
|
||||
logger.info(
|
||||
f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)"
|
||||
)
|
||||
if available_keys_pool and len(keys_failed_429) < key_switch_limit_429:
|
||||
logger.info(" 尝试因 429 切换到下一个可用 Key...")
|
||||
raise _SwitchKeyException()
|
||||
|
|
@ -493,11 +539,15 @@ class LLMRequest:
|
|||
logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429,按标准重试流程。")
|
||||
|
||||
elif response.status == 403 and is_key_list:
|
||||
logger.error(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。")
|
||||
logger.error(
|
||||
f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。"
|
||||
)
|
||||
if current_key not in keys_abandoned_runtime:
|
||||
keys_abandoned_runtime.add(current_key)
|
||||
LLMRequest._abandoned_keys_runtime.add(current_key)
|
||||
logger.critical(f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!")
|
||||
logger.critical(
|
||||
f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!"
|
||||
)
|
||||
if current_key in available_keys_pool:
|
||||
available_keys_pool.remove(current_key)
|
||||
if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403:
|
||||
|
|
@ -506,11 +556,16 @@ class LLMRequest:
|
|||
else:
|
||||
logger.error(" 无更多 Key 可因 403 切换或已达上限。将中止请求。")
|
||||
await response.read()
|
||||
raise PermissionDeniedException(f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。", key_identifier=current_key)
|
||||
raise PermissionDeniedException(
|
||||
f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。",
|
||||
key_identifier=current_key,
|
||||
)
|
||||
else:
|
||||
logger.error(f" Key ...{current_key[-4:]} 再次遇到 403,这不应发生。中止请求。")
|
||||
await response.read()
|
||||
raise PermissionDeniedException(f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key)
|
||||
raise PermissionDeniedException(
|
||||
f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key
|
||||
)
|
||||
|
||||
elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
|
||||
await self._handle_error_response(response, attempt, policy, current_key)
|
||||
|
|
@ -518,14 +573,20 @@ class LLMRequest:
|
|||
if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1:
|
||||
if response.status not in [429, 403]:
|
||||
wait_time = policy["base_wait"] * (2**attempt)
|
||||
logger.warning(f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试...")
|
||||
logger.warning(
|
||||
f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
last_exception = RuntimeError(f"重试错误 {response.status}")
|
||||
continue
|
||||
|
||||
if response.status in policy["abort_codes"] or (response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1):
|
||||
if response.status in policy["abort_codes"] or (
|
||||
response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1
|
||||
):
|
||||
if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]:
|
||||
logger.error(f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。")
|
||||
logger.error(
|
||||
f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。"
|
||||
)
|
||||
# await self._handle_error_response(response, attempt, policy, current_key)
|
||||
# await response.read()
|
||||
# final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}"
|
||||
|
|
@ -579,10 +640,14 @@ class LLMRequest:
|
|||
except Exception as e:
|
||||
# (代码不变)
|
||||
last_exception = e
|
||||
logger.warning(f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}")
|
||||
logger.warning(
|
||||
f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}"
|
||||
)
|
||||
|
||||
if attempt >= policy["max_retries"] - 1:
|
||||
logger.error(f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。")
|
||||
logger.error(
|
||||
f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
temp_request_content = {
|
||||
|
|
@ -620,7 +685,9 @@ class LLMRequest:
|
|||
logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}")
|
||||
raise last_exception
|
||||
logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}")
|
||||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}") from last_exception
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}"
|
||||
) from last_exception
|
||||
else:
|
||||
if not available_keys_pool and keys_abandoned_runtime:
|
||||
final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
|
||||
|
|
@ -629,7 +696,6 @@ class LLMRequest:
|
|||
else:
|
||||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败,原因未知。")
|
||||
|
||||
|
||||
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
|
||||
"""处理 OpenAI 兼容的流式输出"""
|
||||
# (代码不变)
|
||||
|
|
@ -667,20 +733,28 @@ class LLMRequest:
|
|||
tool_calls = []
|
||||
for tc in delta["tool_calls"]:
|
||||
new_tc = dict(tc)
|
||||
if 'function' in new_tc and 'arguments' not in new_tc['function']:
|
||||
new_tc['function']['arguments'] = ""
|
||||
if "function" in new_tc and "arguments" not in new_tc["function"]:
|
||||
new_tc["function"]["arguments"] = ""
|
||||
tool_calls.append(new_tc)
|
||||
else:
|
||||
for i, tc_delta in enumerate(delta["tool_calls"]):
|
||||
if i < len(tool_calls) and 'function' in tc_delta and 'arguments' in tc_delta['function']:
|
||||
if 'arguments' in tool_calls[i]['function']:
|
||||
tool_calls[i]['function']['arguments'] += tc_delta['function']['arguments']
|
||||
if (
|
||||
i < len(tool_calls)
|
||||
and "function" in tc_delta
|
||||
and "arguments" in tc_delta["function"]
|
||||
):
|
||||
if "arguments" in tool_calls[i]["function"]:
|
||||
tool_calls[i]["function"]["arguments"] += tc_delta["function"][
|
||||
"arguments"
|
||||
]
|
||||
else:
|
||||
tool_calls[i]['function']['arguments'] = tc_delta['function']['arguments']
|
||||
tool_calls[i]["function"]["arguments"] = tc_delta["function"][
|
||||
"arguments"
|
||||
]
|
||||
elif i >= len(tool_calls):
|
||||
new_tc = dict(tc_delta)
|
||||
if 'function' in new_tc and 'arguments' not in new_tc['function']:
|
||||
new_tc['function']['arguments'] = ""
|
||||
if "function" in new_tc and "arguments" not in new_tc["function"]:
|
||||
new_tc["function"]["arguments"] = ""
|
||||
tool_calls.append(new_tc)
|
||||
|
||||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||||
|
|
@ -751,7 +825,9 @@ class LLMRequest:
|
|||
|
||||
elif status in policy["retry_codes"] and status != 429:
|
||||
if status == 413:
|
||||
logger.warning(f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩...")
|
||||
logger.warning(
|
||||
f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩..."
|
||||
)
|
||||
raise PayLoadTooLargeError("请求体过大")
|
||||
elif status in [500, 503]:
|
||||
logger.error(
|
||||
|
|
@ -760,7 +836,9 @@ class LLMRequest:
|
|||
)
|
||||
return
|
||||
else:
|
||||
logger.warning(f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}")
|
||||
logger.warning(
|
||||
f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}"
|
||||
)
|
||||
return
|
||||
|
||||
elif status in policy["abort_codes"]:
|
||||
|
|
@ -770,14 +848,15 @@ class LLMRequest:
|
|||
)
|
||||
raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
|
||||
else:
|
||||
logger.error(f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}")
|
||||
logger.error(
|
||||
f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
raise RequestAbortException(f"未处理的错误状态码 {status}", response)
|
||||
except aiohttp.ClientResponseError as e:
|
||||
raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e
|
||||
|
||||
|
||||
async def _handle_exception(
|
||||
self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None
|
||||
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
|
||||
|
|
@ -799,7 +878,10 @@ class LLMRequest:
|
|||
compressed_image_base64 = compress_base64_image_by_scale(image_base64)
|
||||
if compressed_image_base64 != image_base64:
|
||||
new_payload = await self._build_payload(
|
||||
request_content["prompt"], compressed_image_base64, request_content["image_format"], params_for_rebuild
|
||||
request_content["prompt"],
|
||||
compressed_image_base64,
|
||||
request_content["image_format"],
|
||||
params_for_rebuild,
|
||||
)
|
||||
logger.info("图片压缩成功,将使用压缩后的图片重试。")
|
||||
return new_payload, 0
|
||||
|
|
@ -826,7 +908,7 @@ class LLMRequest:
|
|||
f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}"
|
||||
)
|
||||
try:
|
||||
error_text = await exception.response.text() if hasattr(exception, 'response') else str(exception)
|
||||
error_text = await exception.response.text() if hasattr(exception, "response") else str(exception)
|
||||
logger.error(f"服务器错误响应详情: {error_text[:500]}")
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}")
|
||||
|
|
@ -837,23 +919,30 @@ class LLMRequest:
|
|||
)
|
||||
current_key_placeholder = request_content.get("current_key", "******")
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||
) from exception
|
||||
|
||||
else:
|
||||
if keep_request:
|
||||
logger.error(f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}")
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}"
|
||||
)
|
||||
return None, 0
|
||||
else:
|
||||
logger.critical(f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}")
|
||||
logger.critical(
|
||||
f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}"
|
||||
)
|
||||
current_key_placeholder = request_content.get("current_key", "******")
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
|
||||
)
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception
|
||||
|
||||
|
||||
async def _transform_parameters(self, merged_params: dict) -> dict:
|
||||
"""根据模型名称转换合并后的参数,并移除内部参数"""
|
||||
# (代码不变)
|
||||
|
|
@ -884,7 +973,9 @@ class LLMRequest:
|
|||
|
||||
return new_params
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None) -> dict:
|
||||
async def _build_payload(
|
||||
self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None
|
||||
) -> dict:
|
||||
"""构建请求体 (区分 Gemini 和 OpenAI),使用合并和转换后的参数"""
|
||||
# (代码不变)
|
||||
if merged_params is None:
|
||||
|
|
@ -898,16 +989,8 @@ class LLMRequest:
|
|||
parts.append({"text": prompt})
|
||||
if image_base64:
|
||||
mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}"
|
||||
parts.append({
|
||||
"inlineData": {
|
||||
"mimeType": mime_type,
|
||||
"data": image_base64
|
||||
}
|
||||
})
|
||||
payload = {
|
||||
"contents": [{"parts": parts}],
|
||||
**params_copy
|
||||
}
|
||||
parts.append({"inlineData": {"mimeType": mime_type, "data": image_base64}})
|
||||
payload = {"contents": [{"parts": parts}], **params_copy}
|
||||
payload.pop("model", None)
|
||||
# --- 添加 Gemini 安全设置 ---
|
||||
safety_settings = [
|
||||
|
|
@ -930,7 +1013,9 @@ class LLMRequest:
|
|||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"},
|
||||
"image_url": {
|
||||
"url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
@ -957,8 +1042,8 @@ class LLMRequest:
|
|||
"""默认响应解析 (区分 Gemini 和 OpenAI),并处理函数/工具调用"""
|
||||
content = "没有返回结果"
|
||||
reasoning_content = ""
|
||||
tool_calls = None # OpenAI 格式
|
||||
function_call = None # Gemini 格式
|
||||
tool_calls = None # OpenAI 格式
|
||||
function_call = None # Gemini 格式
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
|
|
@ -974,13 +1059,13 @@ class LLMRequest:
|
|||
final_text_parts = []
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "functionCall" in part:
|
||||
function_call = part["functionCall"] # 获取 Gemini 的 functionCall
|
||||
function_call = part["functionCall"] # 获取 Gemini 的 functionCall
|
||||
# Gemini functionCall 通常不与 text 一起返回,这里假设只处理 functionCall
|
||||
break # 找到 functionCall 就停止处理 parts
|
||||
break # 找到 functionCall 就停止处理 parts
|
||||
elif "text" in part:
|
||||
final_text_parts.append(part.get("text", ""))
|
||||
|
||||
if not function_call: # 如果没有 functionCall,处理 text
|
||||
if not function_call: # 如果没有 functionCall,处理 text
|
||||
raw_content = "".join(final_text_parts).strip()
|
||||
content, reasoning = self._extract_reasoning(raw_content)
|
||||
reasoning_content = reasoning
|
||||
|
|
@ -1028,7 +1113,7 @@ class LLMRequest:
|
|||
explicit_reasoning = message.get("reasoning_content", "")
|
||||
reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
|
||||
|
||||
tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls
|
||||
tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls
|
||||
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
|
|
@ -1054,26 +1139,25 @@ class LLMRequest:
|
|||
else:
|
||||
logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
|
||||
|
||||
|
||||
# --- 返回结果 (统一格式) ---
|
||||
final_tool_calls = None
|
||||
if tool_calls: # 来自 OpenAI
|
||||
if tool_calls: # 来自 OpenAI
|
||||
final_tool_calls = tool_calls
|
||||
logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}")
|
||||
elif function_call: # 来自 Gemini
|
||||
elif function_call: # 来自 Gemini
|
||||
logger.debug(f"检测到 Gemini 函数调用: {function_call}")
|
||||
# 将 Gemini functionCall 转换为 OpenAI tool_calls 格式
|
||||
# 注意: Gemini 的 functionCall 没有显式的 id 和 type,需要模拟
|
||||
final_tool_calls = [
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID
|
||||
"id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_call.get("name"),
|
||||
# Gemini 的参数在 'args' 中,OpenAI 在 'arguments' (通常是 JSON 字符串)
|
||||
# 需要将 Gemini 的 dict 参数转换为 JSON 字符串
|
||||
"arguments": json.dumps(function_call.get("args", {}))
|
||||
}
|
||||
"arguments": json.dumps(function_call.get("args", {})),
|
||||
},
|
||||
}
|
||||
]
|
||||
logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}")
|
||||
|
|
@ -1085,7 +1169,6 @@ class LLMRequest:
|
|||
# 没有工具/函数调用,返回普通文本响应
|
||||
return content, reasoning_content
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning(content: str) -> Tuple[str, str]:
|
||||
"""CoT思维链提取"""
|
||||
|
|
@ -1107,7 +1190,7 @@ class LLMRequest:
|
|||
if self.is_gemini:
|
||||
return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
|
||||
else:
|
||||
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
|
||||
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
|
||||
else:
|
||||
if not api_key:
|
||||
logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
|
||||
|
|
@ -1118,16 +1201,11 @@ class LLMRequest:
|
|||
else:
|
||||
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
|
||||
async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
|
||||
"""根据输入的提示生成模型的异步响应,支持覆盖参数"""
|
||||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||||
response = await self._execute_request(
|
||||
endpoint=endpoint,
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
request_type="chat",
|
||||
**kwargs
|
||||
endpoint=endpoint, prompt=prompt, user_id=user_id, request_type="chat", **kwargs
|
||||
)
|
||||
if len(response) == 3:
|
||||
content, reasoning_content, tool_calls = response
|
||||
|
|
@ -1136,7 +1214,9 @@ class LLMRequest:
|
|||
content, reasoning_content = response
|
||||
return content, reasoning_content, self.model_name
|
||||
|
||||
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs) -> Tuple:
|
||||
async def generate_response_for_image(
|
||||
self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs
|
||||
) -> Tuple:
|
||||
"""根据输入的提示和图片生成模型的异步响应,支持覆盖参数"""
|
||||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||||
response = await self._execute_request(
|
||||
|
|
@ -1146,42 +1226,45 @@ class LLMRequest:
|
|||
image_format=image_format,
|
||||
user_id=user_id,
|
||||
request_type="vision",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
# _default_response_handler 现在总是返回至少2个值
|
||||
if len(response) == 3:
|
||||
return response # content, reasoning, tool_calls (tool_calls 可能为 None)
|
||||
return response # content, reasoning, tool_calls (tool_calls 可能为 None)
|
||||
elif len(response) == 2:
|
||||
content, reasoning = response
|
||||
return content, reasoning # 对于 vision 请求,通常没有 tool_calls
|
||||
return content, reasoning # 对于 vision 请求,通常没有 tool_calls
|
||||
else:
|
||||
logger.error(f"来自 _default_response_handler 的意外响应格式: {response}")
|
||||
return "处理响应出错", ""
|
||||
|
||||
|
||||
async def generate_response_async(self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs) -> Union[str, Tuple]:
|
||||
async def generate_response_async(
|
||||
self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs
|
||||
) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
|
||||
# (代码不变)
|
||||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||||
response = await self._execute_request(
|
||||
endpoint=endpoint,
|
||||
prompt=prompt,
|
||||
payload=None,
|
||||
retry_policy=None,
|
||||
response_handler=None,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
**kwargs
|
||||
endpoint=endpoint,
|
||||
prompt=prompt,
|
||||
payload=None,
|
||||
retry_policy=None,
|
||||
response_handler=None,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
# 修改:实现 Gemini Function Calling 的 Payload 构建
|
||||
async def generate_response_tool_async(self, prompt: str, tools: list, user_id: str = "system", **kwargs) -> tuple[str, str, list | None]:
|
||||
async def generate_response_tool_async(
|
||||
self, prompt: str, tools: list, user_id: str = "system", **kwargs
|
||||
) -> tuple[str, str, list | None]:
|
||||
"""异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数和 Gemini 函数调用"""
|
||||
|
||||
endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
|
||||
merged_params = {**self.params, **kwargs}
|
||||
transformed_params = await self._transform_parameters(merged_params) # 清理 request_type 等
|
||||
transformed_params = await self._transform_parameters(merged_params) # 清理 request_type 等
|
||||
|
||||
payload = None
|
||||
|
||||
|
|
@ -1197,36 +1280,39 @@ class LLMRequest:
|
|||
if tool.get("type") == "function" and "function" in tool:
|
||||
func_def = tool["function"]
|
||||
# Gemini parameters 使用 OpenAPI Schema,与 OpenAI 基本兼容
|
||||
function_declarations.append({
|
||||
"name": func_def.get("name"),
|
||||
"description": func_def.get("description", ""), # Description is required for Gemini
|
||||
"parameters": func_def.get("parameters", {"type": "object", "properties": {}}) # Ensure parameters exist
|
||||
})
|
||||
function_declarations.append(
|
||||
{
|
||||
"name": func_def.get("name"),
|
||||
"description": func_def.get("description", ""), # Description is required for Gemini
|
||||
"parameters": func_def.get(
|
||||
"parameters", {"type": "object", "properties": {}}
|
||||
), # Ensure parameters exist
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"跳过不支持的工具类型或格式: {tool}")
|
||||
|
||||
if not function_declarations:
|
||||
logger.error("没有有效的函数声明可用于 Gemini 请求。")
|
||||
return "没有提供有效的函数定义", "", None
|
||||
logger.error("没有有效的函数声明可用于 Gemini 请求。")
|
||||
return "没有提供有效的函数定义", "", None
|
||||
|
||||
gemini_tools = [{"functionDeclarations": function_declarations}]
|
||||
|
||||
# 2. 构建 Gemini Payload
|
||||
# parts = [{"text": prompt}] # 初始 parts
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": prompt}]}], # 包含用户提示
|
||||
"contents": [{"parts": [{"text": prompt}]}], # 包含用户提示
|
||||
"tools": gemini_tools,
|
||||
# toolConfig 默认是 AUTO,可以根据需要从 kwargs 获取或硬编码
|
||||
# "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # 例如强制调用
|
||||
**transformed_params # 合并其他转换后的参数 (如 generationConfig)
|
||||
**transformed_params, # 合并其他转换后的参数 (如 generationConfig)
|
||||
}
|
||||
payload.pop("model", None) # Gemini 不在顶层传 model
|
||||
payload.pop("messages", None) # 移除 OpenAI 特有的 messages
|
||||
payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice
|
||||
payload.pop("model", None) # Gemini 不在顶层传 model
|
||||
payload.pop("messages", None) # 移除 OpenAI 特有的 messages
|
||||
payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice
|
||||
|
||||
logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}")
|
||||
|
||||
|
||||
else:
|
||||
# --- 构建 OpenAI Tool Calling Payload ---
|
||||
# (逻辑不变)
|
||||
|
|
@ -1239,23 +1325,22 @@ class LLMRequest:
|
|||
"tool_choice": transformed_params.get("tool_choice", "auto"),
|
||||
}
|
||||
if "max_completion_tokens" in payload:
|
||||
payload["max_tokens"] = payload.pop("max_completion_tokens")
|
||||
payload["max_tokens"] = payload.pop("max_completion_tokens")
|
||||
if "max_tokens" not in payload:
|
||||
payload["max_tokens"] = global_config.model_max_output_length
|
||||
|
||||
payload["max_tokens"] = global_config.model_max_output_length
|
||||
|
||||
# --- 执行请求 ---
|
||||
if payload is None:
|
||||
logger.error("未能构建有效的 API 请求 payload。")
|
||||
return "内部错误:无法构建请求", "", None
|
||||
logger.error("未能构建有效的 API 请求 payload。")
|
||||
return "内部错误:无法构建请求", "", None
|
||||
|
||||
response = await self._execute_request(
|
||||
endpoint=endpoint,
|
||||
payload=payload,
|
||||
prompt=prompt, # prompt 仍然需要,用于可能的重试
|
||||
prompt=prompt, # prompt 仍然需要,用于可能的重试
|
||||
user_id=user_id,
|
||||
request_type="tool_call",
|
||||
**kwargs # 传递原始 kwargs 以便在重试时重新合并
|
||||
**kwargs, # 传递原始 kwargs 以便在重试时重新合并
|
||||
)
|
||||
|
||||
# _default_response_handler 现在会处理 Gemini functionCall 并统一格式
|
||||
|
|
@ -1273,7 +1358,6 @@ class LLMRequest:
|
|||
logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}")
|
||||
return "处理响应时出错", "", None
|
||||
|
||||
|
||||
async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]:
|
||||
"""异步方法:获取文本的embedding向量,支持覆盖参数 (Gemini Embedding 需注意模型名称)"""
|
||||
# (代码不变)
|
||||
|
|
@ -1281,32 +1365,20 @@ class LLMRequest:
|
|||
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
|
||||
return None
|
||||
|
||||
api_kwargs = {k: v for k, v in kwargs.items() if k != 'request_type'}
|
||||
api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
|
||||
|
||||
if self.is_gemini:
|
||||
endpoint = ":embedContent"
|
||||
payload = {
|
||||
"model": f"models/{self.model_name}",
|
||||
"content": {
|
||||
"parts": [{"text": text}]
|
||||
},
|
||||
**api_kwargs
|
||||
}
|
||||
payload = {"model": f"models/{self.model_name}", "content": {"parts": [{"text": text}]}, **api_kwargs}
|
||||
payload.pop("encoding_format", None)
|
||||
payload.pop("input", None)
|
||||
|
||||
else:
|
||||
endpoint = "/embeddings"
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"encoding_format": "float",
|
||||
**api_kwargs
|
||||
}
|
||||
payload = {"model": self.model_name, "input": text, "encoding_format": "float", **api_kwargs}
|
||||
payload.pop("content", None)
|
||||
payload.pop("taskType", None)
|
||||
|
||||
|
||||
def embedding_handler(result):
|
||||
# (代码不变)
|
||||
embedding_value = None
|
||||
|
|
@ -1346,7 +1418,7 @@ class LLMRequest:
|
|||
response_handler=embedding_handler,
|
||||
user_id=user_id,
|
||||
request_type="embedding",
|
||||
**api_kwargs
|
||||
**api_kwargs,
|
||||
)
|
||||
return embedding
|
||||
|
||||
|
|
@ -1366,33 +1438,40 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
|
|||
new_width = max(1, int(original_width * scale))
|
||||
new_height = max(1, int(original_height * scale))
|
||||
output_buffer = io.BytesIO()
|
||||
save_format = img_format # Default to original format
|
||||
save_format = img_format # Default to original format
|
||||
|
||||
if getattr(img, "is_animated", False) and img.n_frames > 1:
|
||||
frames = []
|
||||
durations = []
|
||||
loop = img.info.get('loop', 0)
|
||||
disposal = img.info.get('disposal', 2)
|
||||
loop = img.info.get("loop", 0)
|
||||
disposal = img.info.get("disposal", 2)
|
||||
logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...")
|
||||
for frame_idx in range(img.n_frames):
|
||||
img.seek(frame_idx)
|
||||
current_duration = img.info.get('duration', 100)
|
||||
current_duration = img.info.get("duration", 100)
|
||||
durations.append(current_duration)
|
||||
new_frame = img.convert("RGBA").copy()
|
||||
resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
frames.append(resized_frame)
|
||||
if frames:
|
||||
frames[0].save(
|
||||
output_buffer, format="GIF", save_all=True, append_images=frames[1:],
|
||||
optimize=False, duration=durations, loop=loop, disposal=disposal,
|
||||
transparency=img.info.get('transparency', None), background=img.info.get('background', None)
|
||||
output_buffer,
|
||||
format="GIF",
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=False,
|
||||
duration=durations,
|
||||
loop=loop,
|
||||
disposal=disposal,
|
||||
transparency=img.info.get("transparency", None),
|
||||
background=img.info.get("background", None),
|
||||
)
|
||||
save_format = "GIF"
|
||||
else:
|
||||
logger.warning("未能处理 GIF 帧。")
|
||||
return base64_data
|
||||
else:
|
||||
if img.mode in ("RGBA", "LA") or 'transparency' in img.info:
|
||||
if img.mode in ("RGBA", "LA") or "transparency" in img.info:
|
||||
resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
save_format = "PNG"
|
||||
save_params = {"optimize": True}
|
||||
|
|
@ -1407,8 +1486,12 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
|
|||
resized_img.save(output_buffer, format=save_format, **save_params)
|
||||
|
||||
compressed_data = output_buffer.getvalue()
|
||||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})")
|
||||
logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)")
|
||||
logger.success(
|
||||
f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})"
|
||||
)
|
||||
logger.info(
|
||||
f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)"
|
||||
)
|
||||
if len(compressed_data) < len(image_data) * 0.95:
|
||||
return base64.b64encode(compressed_data).decode("utf-8")
|
||||
else:
|
||||
|
|
@ -1417,6 +1500,6 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
|
|||
except Exception as e:
|
||||
logger.error(f"压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return base64_data
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue