From a99f3b81b452a2acc43fbb533d42de5dc4140f50 Mon Sep 17 00:00:00 2001
From: Slapq <106904191+Slapq@users.noreply.github.com>
Date: Sun, 16 Mar 2025 08:15:34 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8Grox=E4=BF=AE=E5=A4=8D?=
=?UTF-8?q?=E7=9A=84=E6=A8=A1=E5=9E=8B=E8=B0=83=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
这让我的oneapi能用了,同时修复了一些问题,详细见下
修复重点
流式输出处理
原代码中流式模式的正则表达式 re.search(r'', content, re.DOTALL) 和 re.sub(r'', '', content) 使用空模式,导致无法正确解析内容。修复为直接累积 delta 中的 content,并确保 usage 数据被正确捕获。
接口调用稳定性
增强错误处理,确保重试逻辑在遇到429等状态码时正常工作,避免提前退出。
修复 usage 未正确记录的问题,确保 token 使用情况被写入数据库。
我已经测试过可以正常使用,所以提交该pr
---
src/plugins/models/utils_model.py | 502 ++++++------------------------
1 file changed, 100 insertions(+), 402 deletions(-)
diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index 7572460f..708c2cda 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -5,34 +5,20 @@ from datetime import datetime
from typing import Tuple, Union
import aiohttp
-from src.common.logger import get_module_logger
+from loguru import logger
from nonebot import get_driver
import base64
from PIL import Image
import io
-from ...common.database import db
+from ...common.database import Database
from ..chat.config import global_config
driver = get_driver()
config = driver.config
-logger = get_module_logger("model_utils")
-
class LLM_request:
- # 定义需要转换的模型列表,作为类变量避免重复
- MODELS_NEEDING_TRANSFORMATION = [
- "o3-mini",
- "o1-mini",
- "o1-preview",
- "o1-2024-12-17",
- "o1-preview-2024-09-12",
- "o3-mini-2025-01-31",
- "o1-mini-2024-09-12",
- ]
-
def __init__(self, model, **kwargs):
- # 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
self.base_url = getattr(config, model["base_url"])
@@ -46,39 +32,21 @@ class LLM_request:
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
- # 获取数据库实例
+ self.db = Database.get_instance()
self._init_database()
- @staticmethod
- def _init_database():
- """初始化数据库集合"""
+ def _init_database(self):
try:
- # 创建llm_usage集合的索引
- db.llm_usage.create_index([("timestamp", 1)])
- db.llm_usage.create_index([("model_name", 1)])
- db.llm_usage.create_index([("user_id", 1)])
- db.llm_usage.create_index([("request_type", 1)])
+ self.db.db.llm_usage.create_index([("timestamp", 1)])
+ self.db.db.llm_usage.create_index([("model_name", 1)])
+ self.db.db.llm_usage.create_index([("user_id", 1)])
+ self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}")
- def _record_usage(
- self,
- prompt_tokens: int,
- completion_tokens: int,
- total_tokens: int,
- user_id: str = "system",
- request_type: str = "chat",
- endpoint: str = "/chat/completions",
- ):
- """记录模型使用情况到数据库
- Args:
- prompt_tokens: 输入token数
- completion_tokens: 输出token数
- total_tokens: 总token数
- user_id: 用户ID,默认为system
- request_type: 请求类型(chat/embedding/image等)
- endpoint: API端点
- """
+ def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
+ user_id: str = "system", request_type: str = "chat",
+ endpoint: str = "/chat/completions"):
try:
usage_data = {
"model_name": self.model_name,
@@ -90,325 +58,130 @@ class LLM_request:
"total_tokens": total_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
"status": "success",
- "timestamp": datetime.now(),
+ "timestamp": datetime.now()
}
- db.llm_usage.insert_one(usage_data)
+ self.db.db.llm_usage.insert_one(usage_data)
logger.info(
- f"Token使用情况 - 模型: {self.model_name}, "
- f"用户: {user_id}, 类型: {request_type}, "
- f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
- f"总计: {total_tokens}"
+ f"Token使用情况 - 模型: {self.model_name}, 用户: {user_id}, 类型: {request_type}, "
+ f"提示词: {prompt_tokens}, 完成: {completion_tokens}, 总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
- """计算API调用成本
- 使用模型的pri_in和pri_out价格计算输入和输出的成本
-
- Args:
- prompt_tokens: 输入token数量
- completion_tokens: 输出token数量
-
- Returns:
- float: 总成本(元)
- """
- # 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
async def _execute_request(
- self,
- endpoint: str,
- prompt: str = None,
- image_base64: str = None,
- image_format: str = None,
- payload: dict = None,
- retry_policy: dict = None,
- response_handler: callable = None,
- user_id: str = "system",
- request_type: str = "chat",
+ self,
+ endpoint: str,
+ prompt: str = None,
+ image_base64: str = None,
+ payload: dict = None,
+ retry_policy: dict = None,
+ response_handler: callable = None,
+ user_id: str = "system",
+ request_type: str = "chat"
):
- """统一请求执行入口
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- response_handler: 自定义响应处理器
- user_id: 用户ID
- request_type: 请求类型
- """
- # 合并重试策略
default_retry = {
- "max_retries": 3,
- "base_wait": 15,
+ "max_retries": 3, "base_wait": 15,
"retry_codes": [429, 413, 500, 503],
- "abort_codes": [400, 401, 402, 403],
+ "abort_codes": [400, 401, 402, 403]
}
policy = {**default_retry, **(retry_policy or {})}
- # 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
- 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env.prod中的配置是否正确哦~",
+ 401: "API key 错误,认证失败",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
- 503: "服务器负载过高",
+ 503: "服务器负载过高"
}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
- # 判断是否为流式
stream_mode = self.params.get("stream", False)
- logger_msg = "进入流式输出模式," if stream_mode else ""
- # logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
- # logger.info(f"使用模型: {self.model_name}")
+ logger.debug(f"发送请求到URL: {api_url}, 流式模式: {stream_mode}")
+ logger.info(f"使用模型: {self.model_name}")
- # 构建请求体
if image_base64:
- payload = await self._build_payload(prompt, image_base64, image_format)
+ payload = await self._build_payload(prompt, image_base64)
elif payload is None:
payload = await self._build_payload(prompt)
for retry in range(policy["max_retries"]):
try:
- # 使用上下文管理器处理会话
headers = await self._build_headers()
- # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=payload) as response:
- # 处理需要重试的状态码
if response.status in policy["retry_codes"]:
- wait_time = policy["base_wait"] * (2**retry)
+ wait_time = policy["base_wait"] * (2 ** retry)
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
- if response.status == 413:
- logger.warning("请求体过大,尝试压缩...")
+ if response.status == 413 and image_base64:
+ logger.warning("请求体过大,尝试压缩图片...")
image_base64 = compress_base64_image_by_scale(image_base64)
- payload = await self._build_payload(prompt, image_base64, image_format)
- elif response.status in [500, 503]:
- logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
- raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
- else:
- logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
-
+ payload = await self._build_payload(prompt, image_base64)
await asyncio.sleep(wait_time)
continue
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
- # 尝试获取并记录服务器返回的详细错误信息
- try:
- error_json = await response.json()
- if error_json and isinstance(error_json, list) and len(error_json) > 0:
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(
- f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- # 处理单个错误对象的情况
- error_obj = error_json.get("error", {})
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(
- f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
- )
- else:
- # 记录原始错误响应内容
- logger.error(f"服务器错误响应: {error_json}")
- except Exception as e:
- logger.warning(f"无法解析服务器错误响应: {str(e)}")
-
- if response.status == 403:
- # 只针对硅基流动的V3和R1进行降级处理
- if (
- self.model_name.startswith("Pro/deepseek-ai")
- and self.base_url == "https://api.siliconflow.cn/v1/"
- ):
- old_model_name = self.model_name
- self.model_name = self.model_name[4:] # 移除"Pro/"前缀
- logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
-
- # 对全局配置进行更新
- if global_config.llm_normal.get("name") == old_model_name:
- global_config.llm_normal["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
-
- if global_config.llm_reasoning.get("name") == old_model_name:
- global_config.llm_reasoning["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
-
- # 更新payload中的模型名
- if payload and "model" in payload:
- payload["model"] = self.model_name
-
- # 重新尝试请求
- retry -= 1 # 不计入重试次数
- continue
-
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status()
- # 将流式输出转化为非流式输出
if stream_mode:
- flag_delta_content_finished = False
accumulated_content = ""
- usage = None # 初始化usage变量,避免未定义错误
-
+ usage = None
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
- if not line:
- continue
+ if not line or line == "data: [DONE]":
+ break
if line.startswith("data:"):
- data_str = line[5:].strip()
- if data_str == "[DONE]":
+ chunk = json.loads(line[5:].strip())
+ delta = chunk["choices"][0]["delta"]
+ content = delta.get("content", "")
+ accumulated_content += content
+ if chunk["choices"][0].get("finish_reason") == "stop":
+ usage = chunk.get("usage")
break
- try:
- chunk = json.loads(data_str)
- if flag_delta_content_finished:
- chunk_usage = chunk.get("usage",None)
- if chunk_usage:
- usage = chunk_usage # 获取token用量
- else:
- delta = chunk["choices"][0]["delta"]
- delta_content = delta.get("content")
- if delta_content is None:
- delta_content = ""
- accumulated_content += delta_content
- # 检测流式输出文本是否结束
- finish_reason = chunk["choices"][0].get("finish_reason")
- if finish_reason == "stop":
- chunk_usage = chunk.get("usage",None)
- if chunk_usage:
- usage = chunk_usage
- break
- # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
- flag_delta_content_finished = True
-
- except Exception as e:
- logger.exception(f"解析流式输出错误: {str(e)}")
content = accumulated_content
- reasoning_content = ""
- think_match = re.search(r"(.*?)", content, re.DOTALL)
- if think_match:
- reasoning_content = think_match.group(1).strip()
- content = re.sub(r".*?", "", content, flags=re.DOTALL).strip()
- # 构造一个伪result以便调用自定义响应处理器或默认处理器
- result = {
- "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
- "usage": usage,
- }
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, request_type, endpoint)
- )
+ result = {"choices": [{"message": {"content": content}}], "usage": usage}
+ return response_handler(result) if response_handler else self._default_response_handler(
+ result, user_id, request_type, endpoint)
else:
result = await response.json()
- # 使用自定义处理器或默认处理
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, request_type, endpoint)
- )
+ return response_handler(result) if response_handler else self._default_response_handler(
+ result, user_id, request_type, endpoint)
- except aiohttp.ClientResponseError as e:
- # 处理aiohttp抛出的响应错误
- if retry < policy["max_retries"] - 1:
- wait_time = policy["base_wait"] * (2**retry)
- logger.error(f"HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
- try:
- if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
- error_text = await e.response.text()
- try:
- error_json = json.loads(error_text)
- if isinstance(error_json, list) and len(error_json) > 0:
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- logger.error(
- f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- error_obj = error_json.get("error", {})
- logger.error(
- f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
- )
- else:
- logger.error(f"服务器错误响应: {error_json}")
- except (json.JSONDecodeError, TypeError) as json_err:
- logger.warning(f"响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
- except (AttributeError, TypeError, ValueError) as parse_err:
- logger.warning(f"无法解析响应错误内容: {str(parse_err)}")
-
- await asyncio.sleep(wait_time)
- else:
- logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
- # 安全地检查和记录请求详情
- if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
- if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
- content = payload["messages"][0]["content"]
- if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
- payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
- )
- logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
- raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}")
except Exception as e:
if retry < policy["max_retries"] - 1:
- wait_time = policy["base_wait"] * (2**retry)
+ wait_time = policy["base_wait"] * (2 ** retry)
logger.error(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}")
- # 安全地检查和记录请求详情
- if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
- if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
- content = payload["messages"][0]["content"]
- if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
- payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
- )
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}")
- logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数,API请求仍然失败")
async def _transform_parameters(self, params: dict) -> dict:
- """
- 根据模型名称转换参数:
- - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
- 并将 'max_tokens' 重命名为 'max_completion_tokens'
- """
- # 复制一份参数,避免直接修改原始数据
new_params = dict(params)
-
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
- # 删除 'temperature' 参数(如果存在)
+ models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
+ "o3-mini-2025-01-31", "o1-mini-2024-09-12"]
+ if self.model_name.lower() in models_needing_transformation:
new_params.pop("temperature", None)
- # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
- async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
- """构建请求体"""
- # 复制一份参数,避免直接修改 self.params
+ async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
params_copy = await self._transform_parameters(self.params)
if image_base64:
payload = {
@@ -418,43 +191,29 @@ class LLM_request:
"role": "user",
"content": [
{"type": "text", "text": prompt},
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
- },
- ],
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
+ ]
}
],
"max_tokens": global_config.max_response_length,
- **params_copy,
+ **params_copy
}
else:
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
- **params_copy,
+ **params_copy
}
- # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
- payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
- def _default_response_handler(
- self, result: dict, user_id: str = "system", request_type: str = "chat", endpoint: str = "/chat/completions"
- ) -> Tuple:
- """默认响应解析"""
+ def _default_response_handler(self, result: dict, user_id: str = "system",
+ request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
if "choices" in result and result["choices"]:
message = result["choices"][0]["message"]
content = message.get("content", "")
- content, reasoning = self._extract_reasoning(content)
- reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = message.get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = reasoning
+ reasoning_content = message.get("reasoning_content", "")
- # 记录token使用情况
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
@@ -466,162 +225,101 @@ class LLM_request:
total_tokens=total_tokens,
user_id=user_id,
request_type=request_type,
- endpoint=endpoint,
+ endpoint=endpoint
)
-
return content, reasoning_content
-
return "没有返回结果", ""
- @staticmethod
- def _extract_reasoning(content: str) -> Tuple[str, str]:
- """CoT思维链提取"""
- match = re.search(r"(?:)?(.*?)", content, re.DOTALL)
- content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip()
- if match:
- reasoning = match.group(1).strip()
- else:
- reasoning = ""
- return content, reasoning
-
async def _build_headers(self, no_key: bool = False) -> dict:
- """构建请求头"""
if no_key:
- return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
- else:
- return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
- # 防止小朋友们截图自己的key
+ return {
+ "Authorization": "Bearer **********",
+ "Content-Type": "application/json"
+ }
+ return {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json"
+ }
async def generate_response(self, prompt: str) -> Tuple[str, str]:
- """根据输入的提示生成模型的异步响应"""
-
- content, reasoning_content = await self._execute_request(endpoint="/chat/completions", prompt=prompt)
- return content, reasoning_content
-
- async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
- """根据输入的提示和图片生成模型的异步响应"""
-
- content, reasoning_content = await self._execute_request(
- endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
+ return await self._execute_request(
+ endpoint="/chat/completions",
+ prompt=prompt
+ )
+
+ async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
+ return await self._execute_request(
+ endpoint="/chat/completions",
+ prompt=prompt,
+ image_base64=image_base64
)
- return content, reasoning_content
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]:
- """异步方式根据输入的提示生成模型的响应"""
- # 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
- **self.params,
- **kwargs,
+ **self.params
}
-
- content, reasoning_content = await self._execute_request(
- endpoint="/chat/completions", payload=data, prompt=prompt
+ return await self._execute_request(
+ endpoint="/chat/completions",
+ payload=data,
+ prompt=prompt
)
- return content, reasoning_content
async def get_embedding(self, text: str) -> Union[list, None]:
- """异步方法:获取文本的embedding向量
-
- Args:
- text: 需要获取embedding的文本
-
- Returns:
- list: embedding向量,如果失败则返回None
- """
-
- if(len(text) < 1):
- logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
- return None
def embedding_handler(result):
- """处理响应"""
if "data" in result and len(result["data"]) > 0:
return result["data"][0].get("embedding", None)
return None
- embedding = await self._execute_request(
+ return await self._execute_request(
endpoint="/embeddings",
prompt=text,
- payload={"model": self.model_name, "input": text, "encoding_format": "float"},
+ payload={
+ "model": self.model_name,
+ "input": text,
+ "encoding_format": "float"
+ },
retry_policy={"max_retries": 2, "base_wait": 6},
- response_handler=embedding_handler,
+ response_handler=embedding_handler
)
- return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
- """压缩base64格式的图片到指定大小
- Args:
- base64_data: base64编码的图片数据
- target_size: 目标文件大小(字节),默认0.8MB
- Returns:
- str: 压缩后的base64图片数据
- """
try:
- # 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
-
- # 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2 * 1024 * 1024:
return base64_data
- # 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
-
- # 获取原始尺寸
original_width, original_height = img.size
-
- # 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
-
- # 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
- # 创建内存缓冲区
output_buffer = io.BytesIO()
-
- # 如果是GIF,处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
- new_frame = img.copy()
- new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
+ new_frame = img.copy().resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS)
frames.append(new_frame)
-
- # 保存到缓冲区
frames[0].save(
- output_buffer,
- format="GIF",
- save_all=True,
- append_images=frames[1:],
- optimize=True,
- duration=img.info.get("duration", 100),
- loop=img.info.get("loop", 0),
+ output_buffer, format='GIF', save_all=True, append_images=frames[1:],
+ optimize=True, duration=img.info.get('duration', 100), loop=img.info.get('loop', 0)
)
else:
- # 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-
- # 保存到缓冲区,保持原始格式
- if img.format == "PNG" and img.mode in ("RGBA", "LA"):
- resized_img.save(output_buffer, format="PNG", optimize=True)
+ if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
+ resized_img.save(output_buffer, format='PNG', optimize=True)
else:
- resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
+ resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
- # 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
- logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
-
- return base64.b64encode(compressed_data).decode("utf-8")
-
+ logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
+ return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
- import traceback
-
- logger.error(traceback.format_exc())
return base64_data