diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py
index ccc414c9..a51eba5e 100644
--- a/src/chat/heart_flow/observation/chatting_observation.py
+++ b/src/chat/heart_flow/observation/chatting_observation.py
@@ -81,7 +81,7 @@ class ChattingObservation(Observation):
mid_memory_str = ""
if ids:
for id in ids:
- # print(f"id:{id}")
+ print(f"id:{id}")
try:
for mid_memory in self.mid_memorys:
if mid_memory["id"] == id:
diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py
index 1ab17339..cd452e53 100644
--- a/src/chat/heart_flow/subheartflow_manager.py
+++ b/src/chat/heart_flow/subheartflow_manager.py
@@ -284,7 +284,7 @@ class SubHeartflowManager:
return # 如果不允许,直接返回
# --- 结束新增 ---
- logger.debug(f"当前状态 ({current_state.value}) 可以在{focused_limit}个群 专注聊天")
+ logger.info(f"当前状态 ({current_state.value}) 可以在{focused_limit}个群 专注聊天")
if focused_limit <= 0:
# logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 不允许 FOCUSED 子心流")
@@ -402,7 +402,7 @@ class SubHeartflowManager:
_mai_state_description = f"你当前状态: {current_mai_state.value}。"
individuality = Individuality.get_instance()
personality_prompt = individuality.get_prompt(x_person=2, level=3)
- prompt_personality = f"你是{individuality.name},{personality_prompt}"
+ prompt_personality = f"你正在扮演名为{individuality.name}的人类,{personality_prompt}"
# --- 修改:在 prompt 中加入当前聊天计数和群名信息 (条件显示) ---
chat_status_lines = []
diff --git a/src/chat/knowledge/src/qa_manager.py b/src/chat/knowledge/src/qa_manager.py
index 8f9266d6..06c21e88 100644
--- a/src/chat/knowledge/src/qa_manager.py
+++ b/src/chat/knowledge/src/qa_manager.py
@@ -61,7 +61,7 @@ class QAManager:
for res in relation_search_res:
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
- logger.debug(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
+ print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
# TODO: 使用LLM过滤三元组结果
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
@@ -77,16 +77,16 @@ class QAManager:
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
if len(relation_search_res) != 0:
- logger.debug("找到相关关系,将使用RAG进行检索")
+ logger.info("找到相关关系,将使用RAG进行检索")
# 使用KG检索
part_start_time = time.perf_counter()
result, ppr_node_weights = self.kg_manager.kg_search(
relation_search_res, paragraph_search_res, self.embed_manager
)
part_end_time = time.perf_counter()
- logger.debug(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
+ logger.infoinfo(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
else:
- logger.debug("未找到相关关系,将使用文段检索结果")
+ logger.infoinfo("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res
ppr_node_weights = None
@@ -95,7 +95,7 @@ class QAManager:
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
- logger.debug(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
+ logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
else:
diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py
index b5b0f6e7..a42a11a8 100644
--- a/src/chat/message_receive/message.py
+++ b/src/chat/message_receive/message.py
@@ -1,4 +1,3 @@
-# TODO: 原生多模态支持
import time
from abc import abstractmethod
from dataclasses import dataclass
diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py
index ee810bf3..35f445a2 100644
--- a/src/chat/models/utils_model.py
+++ b/src/chat/models/utils_model.py
@@ -1,43 +1,29 @@
import asyncio
import json
-import random # 添加 random 模块导入
import re
from datetime import datetime
-from typing import Tuple, Union, Dict, Any, Set # 引入 Set
+from typing import Tuple, Union, Dict, Any
import aiohttp
from aiohttp.client import ClientResponse
-# 相对路径导入,根据你的项目结构调整
-# 例如,如果 utils_model.py 在 src/utils/ 下,而 logger 在 src/common/ 下
-# from ..common.logger import get_module_logger
-# from ..common.database import db
-# from ..config.config import global_config
-# 假设它们在期望的路径
from src.common.logger import get_module_logger
-from ...common.database import db
-from ...config.config import global_config
-
-
import base64
from PIL import Image
import io
import os
-
+from ...common.database import db
+from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
-# 尝试加载 .env 文件中的环境变量 (如果项目结构需要)
-# load_dotenv() # 如果你的 .env 文件不在标准位置,可能需要指定路径 load_dotenv(dotenv_path='path/to/.env')
-
logger = get_module_logger("model_utils")
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
- # (代码不变)
def __init__(self, message: str):
super().__init__(message)
self.message = message
@@ -49,7 +35,6 @@ class PayLoadTooLargeError(Exception):
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
- # (代码不变)
def __init__(self, message: str, response: ClientResponse):
super().__init__(message)
self.message = message
@@ -62,31 +47,20 @@ class RequestAbortException(Exception):
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
- # (代码不变)
- def __init__(self, message: str, key_identifier: str = None): # 添加 key 标识符
+ def __init__(self, message: str):
super().__init__(message)
self.message = message
- self.key_identifier = key_identifier # 存储导致 403 的 key
def __str__(self):
return self.message
-# 新增:用于内部标记需要切换 Key 的异常
-class _SwitchKeyException(Exception):
- """内部异常,用于标记需要切换Key并且跳过标准等待时间."""
-
- # (代码不变)
- pass
-
-
# 常见Error Code Mapping
error_code_mapping = {
- # (代码不变)
400: "参数不正确",
- 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", # 401 也可能是 Key 无效
+ 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
402: "账号余额不足",
- 403: "需要实名,或余额不足,或Key无权限", # 扩展 403 的含义
+ 403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
@@ -95,42 +69,29 @@ error_code_mapping = {
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
- """安全地记录请求内容,隐藏敏感信息"""
- # (代码不变)
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
- is_gemini_payload = payload and isinstance(payload, dict) and "contents" in payload
- safe_payload = json.loads(json.dumps(payload)) if payload else {}
-
- if image_base64 and safe_payload and isinstance(safe_payload, dict):
- if "messages" in safe_payload and len(safe_payload["messages"]) > 0:
- if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
- content = safe_payload["messages"][0]["content"]
- if (
- isinstance(content, list)
- and len(content) > 1
- and isinstance(content[1], dict)
- and "image_url" in content[1]
- ):
- safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
- f"{image_base64[:10]}...{image_base64[-10:]}"
- )
- elif is_gemini_payload and "contents" in safe_payload and len(safe_payload["contents"]) > 0:
- if isinstance(safe_payload["contents"][0], dict) and "parts" in safe_payload["contents"][0]:
- parts = safe_payload["contents"][0]["parts"]
- for i, part in enumerate(parts):
- if isinstance(part, dict) and "inlineData" in part:
- safe_payload["contents"][0]["parts"][i]["inlineData"]["data"] = (
- f"{image_base64[:10]}...{image_base64[-10:]}"
- )
- break
-
- return safe_payload
+ if (
+ image_base64
+ and payload
+ and isinstance(payload, dict)
+ and "messages" in payload
+ and len(payload["messages"]) > 0
+ ):
+ if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+ content = payload["messages"][0]["content"]
+ if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+ payload["messages"][0]["content"][1]["image_url"]["url"] = (
+ f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
+ f"{image_base64[:10]}...{image_base64[-10:]}"
+ )
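+ # Illustrative effect (assuming a jpeg image): the url field becomes
+ # "data:image/jpeg;base64,<first 10 chars>...<last 10 chars>", so only a
+ # truncated stub of the original base64 string ends up in the logs.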
+ # if isinstance(content, str) and len(content) > 100:
+ # payload["messages"][0]["content"] = content[:100]
+ return payload
class LLMRequest:
- # (代码不变)
+ # 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
@@ -146,170 +107,34 @@ class LLMRequest:
"o3-mini-2025-01-31o4-mini",
"o4-mini-2025-04-16",
]
- _abandoned_keys_runtime: Set[str] = set()
def __init__(self, model: dict, **kwargs):
- """初始化 LLMRequest 实例"""
- # (代码不变)
- self.model_key_name = model["key"]
- self.model_name: str = model["name"]
- self.params = kwargs
- self.stream = model.get("stream", False)
- self.pri_in = model.get("pri_in", 0)
- self.pri_out = model.get("pri_out", 0)
- self.request_type = model.get("request_type", "default")
-
+ # 将大写的配置键转换为小写并从config中获取实际值
try:
- raw_api_key_config = os.environ[self.model_key_name]
+ self.api_key = os.environ[model["key"]]
self.base_url = os.environ[model["base_url"]]
- self.is_gemini = "googleapis.com" in self.base_url.lower()
- if self.is_gemini:
- logger.debug(f"模型 {self.model_name}: 检测到为 Gemini API (Base URL: {self.base_url})")
- if self.stream:
- logger.warning(f"模型 {self.model_name}: Gemini 流式输出处理与 OpenAI 不同,暂时强制禁用流式。")
- self.stream = False
-
- # 解析和过滤 API Keys (代码不变)
- parsed_keys = []
- is_list_config = False
- try:
- loaded_keys = json.loads(raw_api_key_config)
- if isinstance(loaded_keys, list):
- parsed_keys = [str(key) for key in loaded_keys if key]
- is_list_config = True
- elif isinstance(loaded_keys, str) and loaded_keys:
- parsed_keys = [loaded_keys]
- else:
- raise ValueError(f"Parsed API key for {self.model_key_name} is not a valid list or string.")
- except (json.JSONDecodeError, TypeError) as e:
- if isinstance(raw_api_key_config, list):
- parsed_keys = [str(key) for key in raw_api_key_config if key]
- is_list_config = True
- elif isinstance(raw_api_key_config, str) and raw_api_key_config:
- parsed_keys = [raw_api_key_config]
- else:
- raise ValueError(
- f"Invalid or empty API key config for {self.model_key_name}: {raw_api_key_config}"
- ) from e
-
- if not parsed_keys:
- raise ValueError(f"No valid API keys found for {self.model_key_name}.")
-
- abandoned_key_name = f"abandon_{self.model_key_name}"
- abandoned_keys_set = set()
- raw_abandoned_keys = os.environ.get(abandoned_key_name)
-
- if raw_abandoned_keys:
- try:
- loaded_abandoned = json.loads(raw_abandoned_keys)
- if isinstance(loaded_abandoned, list):
- abandoned_keys_set.update(str(key) for key in loaded_abandoned if key)
- elif isinstance(loaded_abandoned, str) and loaded_abandoned:
- abandoned_keys_set.add(loaded_abandoned)
- logger.info(
- f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' 的废弃 Keys。"
- )
- except (json.JSONDecodeError, TypeError):
- if isinstance(raw_abandoned_keys, list):
- abandoned_keys_set.update(str(key) for key in raw_abandoned_keys if key)
- logger.info(
- f"模型 {model['name']}: 加载了 {len(abandoned_keys_set)} 个来自配置 '{abandoned_key_name}' (直接列表) 的废弃 Keys。"
- )
- elif isinstance(raw_abandoned_keys, str) and raw_abandoned_keys:
- abandoned_keys_set.add(raw_abandoned_keys)
- logger.info(
- f"模型 {model['name']}: 加载了 1 个来自配置 '{abandoned_key_name}' (字符串) 的废弃 Key。"
- )
- else:
- logger.warning(f"无法解析环境变量 '{abandoned_key_name}' 的内容: {raw_abandoned_keys}")
-
- all_abandoned_keys = abandoned_keys_set.union(LLMRequest._abandoned_keys_runtime)
- active_keys = [key for key in parsed_keys if key not in all_abandoned_keys]
-
- if not active_keys:
- logger.error(f"模型 {model['name']}: 所有为 '{self.model_key_name}' 配置的 Keys 都已被废弃或无效。")
- raise ValueError(
- f"No active API keys available for {self.model_key_name} after filtering abandoned keys."
- )
-
- if is_list_config and len(active_keys) > 1:
- self._api_key_config = active_keys
- logger.info(
- f"模型 {model['name']}: 初始化完成,可用 Keys: {len(self._api_key_config)} (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
- )
- elif active_keys:
- self._api_key_config = active_keys[0]
- logger.info(
- f"模型 {model['name']}: 初始化完成,使用单个活动 Key (已排除 {len(all_abandoned_keys)} 个废弃 Keys)。"
- )
- else:
- raise ValueError(f"Unexpected state: No active keys for {self.model_key_name}.")
-
- # 加载代理配置 (代码不变)
- self.proxy_url = None
- self.proxy_models_set = set()
- proxy_host = os.environ.get("PROXY_HOST")
- proxy_port = os.environ.get("PROXY_PORT")
- proxy_models_str = os.environ.get("PROXY_MODELS", "")
-
- if proxy_host and proxy_port:
- try:
- int(proxy_port)
- self.proxy_url = f"http://{proxy_host}:{proxy_port}"
- logger.debug(f"代理已配置: {self.proxy_url}")
-
- if proxy_models_str:
- try:
- cleaned_str = proxy_models_str.strip("'\"")
- self.proxy_models_set = {
- model_name.strip() for model_name in cleaned_str.split(",") if model_name.strip()
- }
- logger.debug(f"以下模型将使用代理: {self.proxy_models_set}")
- except Exception as e:
- logger.error(
- f"解析 PROXY_MODELS ('{proxy_models_str}') 出错: {e}. 代理将不会对特定模型生效。"
- )
- self.proxy_models_set = set()
- except ValueError:
- logger.error(f"无效的代理端口号: {proxy_port}。代理将不被启用。")
- self.proxy_url = None
- self.proxy_models_set = set()
- except Exception as e:
- logger.error(f"加载代理配置时发生错误: {e}")
- self.proxy_url = None
- self.proxy_models_set = set()
- else:
- logger.info("未配置代理服务器 (PROXY_HOST 或 PROXY_PORT 未设置)。")
-
- except KeyError as e:
- # (代码不变)
- missing_key = str(e).strip("'")
- if missing_key == self.model_key_name:
- logger.error(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'")
- raise ValueError(f"配置错误:找不到 API Key 环境变量 '{self.model_key_name}'") from e
- elif missing_key == model["base_url"]:
- logger.error(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'")
- raise ValueError(f"配置错误:找不到 Base URL 环境变量 '{model['base_url']}'") from e
- else:
- logger.error(f"配置错误:找不到环境变量 - {str(e)}")
- raise ValueError(f"配置错误:找不到环境变量 - {str(e)}") from e
except AttributeError as e:
- # (代码不变)
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
- except ValueError as e:
- # (代码不变)
- logger.error(f"API Key 或配置初始化错误 for {self.model_key_name}: {str(e)}")
- raise e
+ self.model_name: str = model["name"]
+ self.params = kwargs
+ self.stream = model.get("stream", False)
+ self.pri_in = model.get("pri_in", 0)
+ self.pri_out = model.get("pri_out", 0)
+
+ # 获取数据库实例
self._init_database()
+ # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
+ self.request_type = kwargs.pop("request_type", "default")
+
@staticmethod
def _init_database():
"""初始化数据库集合"""
- # (代码不变)
try:
+ # 创建llm_usage集合的索引
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
@@ -339,19 +164,12 @@ class LLMRequest:
if request_type is None:
request_type = self.request_type
- actual_endpoint = endpoint
- if self.is_gemini:
- if endpoint == "/embeddings":
- actual_endpoint = ":embedContent"
- else:
- actual_endpoint = ":generateContent"
-
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
- "request_type": request_type or self.request_type,
- "endpoint": actual_endpoint,
+ "request_type": request_type,
+ "endpoint": endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
@@ -362,7 +180,7 @@ class LLMRequest:
db.llm_usage.insert_one(usage_data)
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
- f"用户: {user_id}, 类型: {request_type or self.request_type}, "
+ f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
@@ -370,8 +188,17 @@ class LLMRequest:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
- """计算API调用成本"""
- # (代码不变)
+ """计算API调用成本
+ 使用模型的pri_in和pri_out价格计算输入和输出的成本
+
+ Args:
+ prompt_tokens: 输入token数量
+ completion_tokens: 输出token数量
+
+ Returns:
+ float: 总成本(元)
+ """
+ # 使用模型的pri_in和pri_out计算成本
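+ # Worked example with illustrative prices: pri_in=2.0, pri_out=8.0 (元 per 1M tokens),
+ # 1500 prompt + 500 completion tokens -> 1500/1e6*2.0 + 500/1e6*8.0 = 0.003 + 0.004 = 0.007 元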
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
@@ -384,43 +211,46 @@ class LLMRequest:
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
- **kwargs: Any,
) -> Dict[str, Any]:
- """配置请求参数,合并实例参数和调用时参数"""
+ """配置请求参数
+ Args:
+ endpoint: API端点路径 (如 "chat/completions")
+ prompt: prompt文本
+ image_base64: 图片的base64编码
+ image_format: 图片格式
+ payload: 请求体数据
+ retry_policy: 自定义重试策略
+ request_type: 请求类型
+ """
+
+ # 合并重试策略
default_retry = {
- "max_retries": global_config.api_polling_max_retries,
+ "max_retries": 3,
"base_wait": 10,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
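+ # With these defaults, the backoff computed in the error/exception handlers is
+ # base_wait * 2**retry, i.e. waits of 10s, 20s and 40s across the three attempts.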
- _actual_endpoint = endpoint
- if self.is_gemini:
- action = endpoint.lstrip("/")
- api_url = f"{self.base_url.rstrip('/')}/{self.model_name}{action}"
- stream_mode = False
- else:
- api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
- stream_mode = self.stream
+ api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
- call_params = {k: v for k, v in kwargs.items() if k != "request_type"}
- merged_params = {**self.params, **call_params}
+ stream_mode = self.stream
- if payload is None:
- payload = await self._build_payload(prompt, image_base64, image_format, merged_params)
- else:
- logger.debug("使用外部提供的 payload,忽略单次调用参数合并。")
+ # 构建请求体
+ if image_base64:
+ payload = await self._build_payload(prompt, image_base64, image_format)
+ elif payload is None:
+ payload = await self._build_payload(prompt)
- if not self.is_gemini and stream_mode:
- payload["stream"] = merged_params.get("stream", stream_mode)
+ if stream_mode:
+ payload["stream"] = stream_mode
return {
"policy": policy,
"payload": payload,
"api_url": api_url,
- "stream_mode": payload.get("stream", False),
- "image_base64": image_base64,
+ "stream_mode": stream_mode,
+ "image_base64": image_base64, # 保留必要的exception处理所需的原始数据
"image_format": image_format,
"prompt": prompt,
}
@@ -436,286 +266,87 @@ class LLMRequest:
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
- **kwargs: Any,
):
- """统一请求执行入口, 支持列表 key 切换、代理和单次调用参数覆盖"""
- final_request_type = request_type or kwargs.get("request_type") or self.request_type
- api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
-
+ """统一请求执行入口
+ Args:
+ endpoint: API端点路径 (如 "chat/completions")
+ prompt: prompt文本
+ image_base64: 图片的base64编码
+ image_format: 图片格式
+ payload: 请求体数据
+ retry_policy: 自定义重试策略
+ response_handler: 自定义响应处理器
+ user_id: 用户ID
+ request_type: 请求类型
+ """
+ # 获取请求配置
request_content = await self._prepare_request(
- endpoint, prompt, image_base64, image_format, payload, retry_policy, **api_kwargs
+ endpoint, prompt, image_base64, image_format, payload, retry_policy
)
- policy = request_content["policy"]
- api_url = request_content["api_url"]
- actual_payload = request_content["payload"]
- stream_mode = request_content["stream_mode"]
-
- use_proxy = False
- current_proxy_url = None
- if self.proxy_url and self.model_name in self.proxy_models_set:
- use_proxy = True
- current_proxy_url = self.proxy_url
- logger.debug(f"模型 {self.model_name}: 将通过代理 {current_proxy_url} 发送请求。")
- elif self.proxy_url:
- logger.debug(f"模型 {self.model_name}: 配置了代理,但此模型不在 PROXY_MODELS 列表中,将不使用代理。")
- else:
- logger.debug(f"模型 {self.model_name}: 未配置或不为此模型使用代理。")
-
- current_key = None
- keys_failed_429 = set()
- keys_abandoned_runtime = set()
- key_switch_limit_429 = global_config.api_polling_max_retries
- key_switch_limit_403 = global_config.api_polling_max_retries
-
- available_keys_pool = []
- is_key_list = isinstance(self._api_key_config, list)
-
- if is_key_list:
- available_keys_pool = list(self._api_key_config)
- if not available_keys_pool:
- logger.error(f"模型 {self.model_name}: 初始化后无可用活动 Keys。")
- raise ValueError(f"模型 {self.model_name}: 无可用活动 Keys。")
- random.shuffle(available_keys_pool)
- key_switch_limit_429 = min(key_switch_limit_429, len(available_keys_pool))
- key_switch_limit_403 = min(key_switch_limit_403, len(available_keys_pool))
- logger.info(
- f"模型 {self.model_name}: Key 列表模式,启用 429/403 自动切换(429上限: {key_switch_limit_429}, 403上限: {key_switch_limit_403})。"
- )
- elif isinstance(self._api_key_config, str):
- available_keys_pool = [self._api_key_config]
- key_switch_limit_429 = 1
- key_switch_limit_403 = 1
- else:
- logger.error(f"模型 {self.model_name}: 无效的 API Key 配置类型在执行时遇到: {type(self._api_key_config)}")
- raise TypeError(f"模型 {self.model_name}: 无效的 API Key 配置类型")
-
- last_exception = None
-
- for attempt in range(policy["max_retries"]):
- if available_keys_pool:
- current_key = available_keys_pool.pop(0)
- elif current_key:
- logger.debug(
- f"模型 {self.model_name}: 无新 Key 可用或为单 Key 模式,将使用 Key ...{current_key[-4:]} 进行重试 (第 {attempt + 1} 次尝试)"
- )
- else:
- if (
- not self._api_key_config
- or all(
- k in LLMRequest._abandoned_keys_runtime
- for k in self._api_key_config
- if isinstance(self._api_key_config, list)
- )
- or (
- isinstance(self._api_key_config, str)
- and self._api_key_config in LLMRequest._abandoned_keys_runtime
- )
- ):
- final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
- logger.critical(final_error_msg)
- raise PermissionDeniedException(final_error_msg)
- else:
- raise RuntimeError(f"模型 {self.model_name}: 无法选择 API key (第 {attempt + 1} 次尝试)")
-
- logger.debug(f"模型 {self.model_name}: 尝试使用 Key: ...{current_key[-4:]} (总第 {attempt + 1} 次尝试)")
-
+ if request_type is None:
+ request_type = self.request_type
+ for retry in range(request_content["policy"]["max_retries"]):
try:
- headers = await self._build_headers(current_key)
- if not self.is_gemini and stream_mode:
+ # 使用上下文管理器处理会话
+ headers = await self._build_headers()
+ # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
+ if request_content["stream_mode"]:
headers["Accept"] = "text/event-stream"
-
async with aiohttp.ClientSession() as session:
- post_kwargs = {"headers": headers, "json": actual_payload, "timeout": 60}
- if use_proxy:
- post_kwargs["proxy"] = current_proxy_url
-
- async with session.post(api_url, **post_kwargs) as response:
- if response.status == 429 and is_key_list:
- logger.warning(f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 429 错误。")
- response_text = await response.text()
- logger.debug(
- f"模型 {self.model_name}: Key ...{current_key[-4:]} response:\n{json.dumps(json.loads(response_text), indent=2, ensure_ascii=False)}\napi_url:\n{api_url}\nheader:\n{headers}\npayload:\n{actual_payload}"
- )
- if current_key not in keys_failed_429:
- keys_failed_429.add(current_key)
- logger.info(
- f" (因 429 已失败 {len(keys_failed_429)}/{key_switch_limit_429} 个不同 Key)"
- )
- if available_keys_pool and len(keys_failed_429) < key_switch_limit_429:
- logger.info(" 尝试因 429 切换到下一个可用 Key...")
- raise _SwitchKeyException()
- else:
- logger.warning(" 无更多 Key 可因 429 切换或已达上限。")
- else:
- logger.warning(f" Key ...{current_key[-4:]} 再次遇到 429,按标准重试流程。")
-
- elif response.status == 403 and is_key_list:
- logger.error(
- f"模型 {self.model_name}: Key ...{current_key[-4:]} 遇到 403 (权限拒绝) 错误。"
- )
- if current_key not in keys_abandoned_runtime:
- keys_abandoned_runtime.add(current_key)
- LLMRequest._abandoned_keys_runtime.add(current_key)
- logger.critical(
- f" !! Key ...{current_key[-4:]} 已添加到运行时废弃列表。请考虑将其移至配置中的 'abandon_{self.model_key_name}' !!"
- )
- if current_key in available_keys_pool:
- available_keys_pool.remove(current_key)
- if available_keys_pool and len(keys_abandoned_runtime) < key_switch_limit_403:
- logger.info(" 尝试因 403 切换到下一个可用 Key...")
- raise _SwitchKeyException()
- else:
- logger.error(" 无更多 Key 可因 403 切换或已达上限。将中止请求。")
- await response.read()
- raise PermissionDeniedException(
- f"Key ...{current_key[-4:]} 权限被拒,且无其他可用 Key 切换。",
- key_identifier=current_key,
- )
- else:
- logger.error(f" Key ...{current_key[-4:]} 再次遇到 403,这不应发生。中止请求。")
- await response.read()
- raise PermissionDeniedException(
- f"Key ...{current_key[-4:]} 重复遇到 403。", key_identifier=current_key
- )
-
- elif response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
- await self._handle_error_response(response, attempt, policy, current_key)
-
- if response.status in policy["retry_codes"] and attempt < policy["max_retries"] - 1:
- if response.status not in [429, 403]:
- wait_time = policy["base_wait"] * (2**attempt)
- logger.warning(
- f"模型 {self.model_name}: 遇到可重试错误 {response.status}, 等待 {wait_time} 秒后重试..."
- )
- await asyncio.sleep(wait_time)
- last_exception = RuntimeError(f"重试错误 {response.status}")
- continue
-
- if response.status in policy["abort_codes"] or (
- response.status in policy["retry_codes"] and attempt >= policy["max_retries"] - 1
- ):
- if attempt >= policy["max_retries"] - 1 and response.status in policy["retry_codes"]:
- logger.error(
- f"模型 {self.model_name}: 达到最大重试次数,最后一次尝试仍为可重试错误 {response.status}。"
- )
- # await self._handle_error_response(response, attempt, policy, current_key)
- # await response.read()
- # final_error_msg = f"请求中止或达到最大重试次数,最终状态码: {response.status}"
- # logger.error(final_error_msg)
- # raise RequestAbortException(final_error_msg, response)
-
- response.raise_for_status()
- result = {}
- if not self.is_gemini and stream_mode:
- result = await self._handle_stream_output(response)
- else:
- result = await response.json()
-
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, final_request_type, endpoint)
+ async with session.post(
+ request_content["api_url"], headers=headers, json=request_content["payload"]
+ ) as response:
+ handled_result = await self._handle_response(
+ response, request_content, retry, response_handler, user_id, request_type, endpoint
)
-
- except _SwitchKeyException:
- last_exception = _SwitchKeyException()
- logger.debug("捕获到 _SwitchKeyException,立即进行下一次尝试。")
- continue
- except PermissionDeniedException as e:
- logger.error(f"模型 {self.model_name}: 因权限拒绝 (403) 中止请求: {e}")
- if is_key_list and not available_keys_pool and e.key_identifier:
- logger.critical(f" 中止原因是 Key ...{e.key_identifier[-4:]} 触发 403 后已无其他 Key 可用。")
- raise e
- except aiohttp.ClientProxyConnectionError as e:
- logger.error(f"代理连接错误: {e} (代理地址: {current_proxy_url})")
- last_exception = e
- if attempt >= policy["max_retries"] - 1:
- raise RuntimeError(f"代理连接失败达到最大重试次数: {e}") from e
- wait_time = policy["base_wait"] * (2**attempt)
- logger.warning(f"模型 {self.model_name}: 代理连接错误,等待 {wait_time} 秒后重试...")
- await asyncio.sleep(wait_time)
- continue
- except aiohttp.ClientConnectorError as e:
- logger.error(f"网络连接错误: {e} (URL: {api_url}, 代理: {current_proxy_url})")
- last_exception = e
- if attempt >= policy["max_retries"] - 1:
- raise RuntimeError(f"网络连接失败达到最大重试次数: {e}") from e
- wait_time = policy["base_wait"] * (2**attempt)
- logger.warning(f"模型 {self.model_name}: 网络连接错误,等待 {wait_time} 秒后重试...")
- await asyncio.sleep(wait_time)
- continue
- except (PayLoadTooLargeError, RequestAbortException) as e:
- # (代码不变)
- logger.error(f"模型 {self.model_name}: 请求处理中遇到关键错误,将中止: {e}")
- raise e
+ return handled_result
except Exception as e:
- # (代码不变)
- last_exception = e
- logger.warning(
- f"模型 {self.model_name}: 第 {attempt + 1} 次尝试中发生非 HTTP 错误: {str(e.__class__.__name__)} - {str(e)}"
- )
+ handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
+ retry += count_delta # 降级不计入重试次数
+ if handled_payload:
+ # 如果降级成功,重新构建请求体
+ request_content["payload"] = handled_payload
+ continue
- if attempt >= policy["max_retries"] - 1:
- logger.error(
- f"模型 {self.model_name}: 达到最大重试次数 ({policy['max_retries']}),因非 HTTP 错误失败。"
- )
- else:
- try:
- temp_request_content = {
- "policy": policy,
- "payload": actual_payload,
- "api_url": api_url,
- "stream_mode": stream_mode,
- "image_base64": image_base64,
- "image_format": image_format,
- "prompt": prompt,
- }
- handled_payload, count_delta = await self._handle_exception(
- e, attempt, temp_request_content, merged_params=api_kwargs
- )
- if handled_payload:
- actual_payload = handled_payload
- logger.info(f"模型 {self.model_name}: 异常处理更新了 payload,将使用当前 Key 重试。")
+ logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
+ raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
- wait_time = policy["base_wait"] * (2**attempt)
- logger.warning(f"模型 {self.model_name}: 等待 {wait_time} 秒后重试...")
- await asyncio.sleep(wait_time)
- continue
+ async def _handle_response(
+ self,
+ response: ClientResponse,
+ request_content: Dict[str, Any],
+ retry_count: int,
+ response_handler: callable,
+ user_id,
+ request_type,
+ endpoint,
+ ) -> Union[Dict[str, Any], None]:
+ policy = request_content["policy"]
+ stream_mode = request_content["stream_mode"]
+ if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
+ await self._handle_error_response(response, retry_count, policy)
+ return None
- except (RequestAbortException, PermissionDeniedException) as abort_exception:
- logger.error(f"模型 {self.model_name}: 异常处理判断需要中止请求: {abort_exception}")
- raise abort_exception
- except RuntimeError as rt_error:
- logger.error(f"模型 {self.model_name}: 异常处理遇到运行时错误: {rt_error}")
- raise rt_error
-
- # --- 循环结束 ---
- logger.error(f"模型 {self.model_name}: 所有重试尝试 ({policy['max_retries']} 次) 均失败。")
- if last_exception:
- if isinstance(last_exception, PermissionDeniedException):
- logger.error(f"最后遇到的错误是权限拒绝: {str(last_exception)}")
- raise last_exception
- logger.error(f"最后遇到的错误: {str(last_exception.__class__.__name__)} - {str(last_exception)}")
- raise RuntimeError(
- f"模型 {self.model_name} 达到最大重试次数,API 请求失败。最后错误: {str(last_exception)}"
- ) from last_exception
+ response.raise_for_status()
+ result = {}
+ if stream_mode:
+ # 将流式输出转化为非流式输出
+ result = await self._handle_stream_output(response)
else:
- if not available_keys_pool and keys_abandoned_runtime:
- final_error_msg = f"模型 {self.model_name}: 所有可用 API Keys 均因 403 错误被禁用。"
- logger.critical(final_error_msg)
- raise PermissionDeniedException(final_error_msg)
- else:
- raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API 请求失败,原因未知。")
+ result = await response.json()
+ return (
+ response_handler(result)
+ if response_handler
+ else self._default_response_handler(result, user_id, request_type, endpoint)
+ )
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
- """处理 OpenAI 兼容的流式输出"""
- # (代码不变)
flag_delta_content_finished = False
accumulated_content = ""
- usage = None
+ usage = None # 初始化usage变量,避免未定义错误
reasoning_content = ""
content = ""
- tool_calls = None
+ tool_calls = None # 初始化工具调用变量
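+ # Each streamed line is an SSE event of the form "data: {...}" in the OpenAI-compatible
+ # shape {"choices":[{"delta":{"content": ...}}]}, typically terminated by "data: [DONE]".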
async for line_bytes in response.content:
try:
@@ -731,7 +362,7 @@ class LLMRequest:
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
- usage = chunk_usage
+ usage = chunk_usage # 获取token用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
@@ -739,35 +370,15 @@ class LLMRequest:
delta_content = ""
accumulated_content += delta_content
+ # 提取工具调用信息
if "tool_calls" in delta:
if tool_calls is None:
- tool_calls = []
- for tc in delta["tool_calls"]:
- new_tc = dict(tc)
- if "function" in new_tc and "arguments" not in new_tc["function"]:
- new_tc["function"]["arguments"] = ""
- tool_calls.append(new_tc)
+ tool_calls = delta["tool_calls"]
else:
- for i, tc_delta in enumerate(delta["tool_calls"]):
- if (
- i < len(tool_calls)
- and "function" in tc_delta
- and "arguments" in tc_delta["function"]
- ):
- if "arguments" in tool_calls[i]["function"]:
- tool_calls[i]["function"]["arguments"] += tc_delta["function"][
- "arguments"
- ]
- else:
- tool_calls[i]["function"]["arguments"] = tc_delta["function"][
- "arguments"
- ]
- elif i >= len(tool_calls):
- new_tc = dict(tc_delta)
- if "function" in new_tc and "arguments" not in new_tc["function"]:
- new_tc["function"]["arguments"] = ""
- tool_calls.append(new_tc)
+ # 合并工具调用信息
+ tool_calls.extend(delta["tool_calls"])
+ # 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if delta.get("reasoning_content", None):
reasoning_content += delta["reasoning_content"]
@@ -776,37 +387,37 @@ class LLMRequest:
if chunk_usage:
usage = chunk_usage
break
+ # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
flag_delta_content_finished = True
- except json.JSONDecodeError as e:
- logger.error(f"模型 {self.model_name} 解析流式 JSON 错误: {e} - data: '{data_str}'")
except Exception as e:
- logger.exception(f"模型 {self.model_name} 解析流式输出块错误: {str(e)}")
- except UnicodeDecodeError as e:
- logger.warning(f"模型 {self.model_name} 流式输出解码错误: {e} - bytes: {line_bytes[:50]}...")
+ logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
except Exception as e:
if isinstance(e, GeneratorExit):
log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
else:
log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
logger.warning(log_content)
+ # 确保资源被正确清理
try:
await response.release()
except Exception as cleanup_error:
logger.error(f"清理资源时发生错误: {cleanup_error}")
+ # 返回已经累积的内容
content = accumulated_content
- break
- if not content and accumulated_content:
+ if not content:
content = accumulated_content
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r".*? ", "", content, flags=re.DOTALL).strip()
+ # 构建消息对象
message = {
"content": content,
"reasoning_content": reasoning_content,
}
+ # 如果有工具调用,添加到消息中
if tool_calls:
message["tool_calls"] = tool_calls
@@ -817,407 +428,285 @@ class LLMRequest:
return result
async def _handle_error_response(
- self, response: ClientResponse, retry_count: int, policy: Dict[str, Any], current_key: str = None
- ) -> None:
- """处理 HTTP 错误响应 (区分 403 和其他错误)"""
- # (代码不变)
- status = response.status
- try:
- error_text = await response.text()
- except Exception as e:
- error_text = f"(无法读取响应体: {e})"
-
- if status == 403:
- logger.error(
- f"模型 {self.model_name}: 遇到 403 (权限拒绝) 错误。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
- f"响应: {error_text[:200]}"
- )
- raise PermissionDeniedException(f"模型禁止访问 ({status})", key_identifier=current_key)
-
- elif status in policy["retry_codes"] and status != 429:
- if status == 413:
- logger.warning(
- f"模型 {self.model_name}: 错误码 413 (Payload Too Large)。Key: ...{current_key[-4:] if current_key else 'N/A'}. 尝试压缩..."
- )
+ self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
+ ) -> Union[Dict[str, Any], None]:
+ if response.status in policy["retry_codes"]:
+ wait_time = policy["base_wait"] * (2**retry_count)
+ logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
+ if response.status == 413:
+ logger.warning("请求体过大,尝试压缩...")
raise PayLoadTooLargeError("请求体过大")
- elif status in [500, 503]:
+ elif response.status in [500, 503]:
logger.error(
- f"模型 {self.model_name}: 服务器内部错误或过载 ({status})。Key: ...{current_key[-4:] if current_key else 'N/A'}. "
- f"响应: {error_text[:200]}"
+ f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
- return
+ raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
else:
- logger.warning(
- f"模型 {self.model_name}: 遇到可重试错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}"
- )
- return
-
- elif status in policy["abort_codes"]:
- logger.error(
- f"模型 {self.model_name}: 遇到需要中止的错误码: {status} - {error_code_mapping.get(status, '未知错误')}. "
- f"Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
- )
- raise RequestAbortException(f"请求出现错误 {status},中止处理", response)
- else:
- logger.error(
- f"模型 {self.model_name}: 遇到未明确处理的错误码: {status}. Key: ...{current_key[-4:] if current_key else 'N/A'}. 响应: {error_text[:200]}"
- )
- try:
- response.raise_for_status()
- raise RequestAbortException(f"未处理的错误状态码 {status}", response)
- except aiohttp.ClientResponseError as e:
- raise RequestAbortException(f"未处理的错误状态码 {status}: {e.message}", response) from e
+ logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
+ raise RuntimeError("请求限制(429)")
+ elif response.status in policy["abort_codes"]:
+ if response.status != 403:
+ raise RequestAbortException("请求出现错误,中断处理", response)
+ else:
+ raise PermissionDeniedException("模型禁止访问")
async def _handle_exception(
- self, exception, retry_count: int, request_content: Dict[str, Any], merged_params: Dict[str, Any] = None
+ self, exception, retry_count: int, request_content: Dict[str, Any]
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
- """处理非 HTTP 错误,支持使用合并后的参数重建 payload"""
policy = request_content["policy"]
payload = request_content["payload"]
- _wait_time = policy["base_wait"] * (2**retry_count)
+ wait_time = policy["base_wait"] * (2**retry_count)
keep_request = False
if retry_count < policy["max_retries"] - 1:
keep_request = True
-
- params_for_rebuild = merged_params if merged_params is not None else payload
-
- if isinstance(exception, PayLoadTooLargeError):
- if keep_request:
- logger.warning("请求体过大 (PayLoadTooLargeError),尝试压缩图片...")
- image_base64 = request_content.get("image_base64")
- if image_base64:
- compressed_image_base64 = compress_base64_image_by_scale(image_base64)
- if compressed_image_base64 != image_base64:
- new_payload = await self._build_payload(
- request_content["prompt"],
- compressed_image_base64,
- request_content["image_format"],
- params_for_rebuild,
- )
- logger.info("图片压缩成功,将使用压缩后的图片重试。")
- return new_payload, 0
- else:
- logger.warning("图片压缩未改变大小或失败。")
+ if isinstance(exception, RequestAbortException):
+ response = exception.response
+ logger.error(
+ f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
+ )
+ # 尝试获取并记录服务器返回的详细错误信息
+ try:
+ error_json = await response.json()
+ if error_json and isinstance(error_json, list) and len(error_json) > 0:
+ # 处理多个错误的情况
+ for error_item in error_json:
+ if "error" in error_item and isinstance(error_item["error"], dict):
+ error_obj: dict = error_item["error"]
+ error_code = error_obj.get("code")
+ error_message = error_obj.get("message")
+ error_status = error_obj.get("status")
+ logger.error(
+ f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
+ )
+ elif isinstance(error_json, dict) and "error" in error_json:
+ # 处理单个错误对象的情况
+ error_obj = error_json.get("error", {})
+ error_code = error_obj.get("code")
+ error_message = error_obj.get("message")
+ error_status = error_obj.get("status")
+ logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
else:
- logger.warning("请求体过大但请求中不包含图片,无法压缩。")
- return None, 0
- else:
- logger.error("达到最大重试次数,请求体仍然过大。")
- raise RuntimeError("请求体过大,压缩或重试后仍然失败。") from exception
+ # 记录原始错误响应内容
+ logger.error(f"服务器错误响应: {error_json}")
+ except Exception as e:
+ logger.warning(f"无法解析服务器错误响应: {str(e)}")
+ raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
- elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)):
+ elif isinstance(exception, PermissionDeniedException):
+ # 只针对硅基流动的V3和R1进行降级处理
+ if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
+ old_model_name = self.model_name
+ self.model_name = self.model_name[4:] # 移除"Pro/"前缀
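+ # e.g. "Pro/deepseek-ai/DeepSeek-V3" -> "deepseek-ai/DeepSeek-V3" (illustrative model id)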
+ logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
+
+ # 对全局配置进行更新
+ if global_config.llm_normal.get("name") == old_model_name:
+ global_config.llm_normal["name"] = self.model_name
+ logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
+ if global_config.llm_reasoning.get("name") == old_model_name:
+ global_config.llm_reasoning["name"] = self.model_name
+ logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
+
+ if payload and "model" in payload:
+ payload["model"] = self.model_name
+
+ await asyncio.sleep(wait_time)
+ return payload, -1
+ raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
+
+ elif isinstance(exception, PayLoadTooLargeError):
if keep_request:
- logger.error(f"模型 {self.model_name} 网络错误: {str(exception)}")
+ image_base64 = request_content["image_base64"]
+ compressed_image_base64 = compress_base64_image_by_scale(image_base64)
+ new_payload = await self._build_payload(
+ request_content["prompt"], compressed_image_base64, request_content["image_format"]
+ )
+ return new_payload, 0
+ else:
+ return None, 0
+
+ elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
+ if keep_request:
+ logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
+ await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
- raise RuntimeError(f"网络请求失败: {str(exception)}") from exception
+ raise RuntimeError(f"网络请求失败: {str(exception)}")
elif isinstance(exception, aiohttp.ClientResponseError):
+ # 处理aiohttp抛出的,除了policy中的status的响应错误
if keep_request:
logger.error(
- f"模型 {self.model_name} HTTP响应错误 (未被策略覆盖): 状态码: {exception.status}, 错误: {exception.message}"
+ f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
)
try:
- error_text = await exception.response.text() if hasattr(exception, "response") else str(exception)
- logger.error(f"服务器错误响应详情: {error_text[:500]}")
+ error_text = await exception.response.text()
+ error_json = json.loads(error_text)
+ if isinstance(error_json, list) and len(error_json) > 0:
+ # 处理多个错误的情况
+ for error_item in error_json:
+ if "error" in error_item and isinstance(error_item["error"], dict):
+ error_obj = error_item["error"]
+ logger.error(
+ f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
+ f"状态={error_obj.get('status')}, "
+ f"消息={error_obj.get('message')}"
+ )
+ elif isinstance(error_json, dict) and "error" in error_json:
+ error_obj = error_json.get("error", {})
+ logger.error(
+ f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
+ f"状态={error_obj.get('status')}, "
+ f"消息={error_obj.get('message')}"
+ )
+ else:
+ logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
+ except (json.JSONDecodeError, TypeError) as json_err:
+ logger.warning(
+ f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
+ )
except Exception as parse_err:
- logger.warning(f"无法解析服务器错误响应内容: {str(parse_err)}")
+ logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
+
+ await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
)
- current_key_placeholder = request_content.get("current_key", "******")
+ # 安全地检查和记录请求详情
handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
- )
+ logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
- ) from exception
+ )
else:
if keep_request:
- logger.error(
- f"模型 {self.model_name} 遇到未知错误: {str(exception.__class__.__name__)} - {str(exception)}"
- )
+ logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
+ await asyncio.sleep(wait_time)
return None, 0
else:
- logger.critical(
- f"模型 {self.model_name} 请求因未知错误失败: {str(exception.__class__.__name__)} - {str(exception)}"
- )
- current_key_placeholder = request_content.get("current_key", "******")
+ logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
+ # 安全地检查和记录请求详情
handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(api_key=current_key_placeholder, no_key=True)} 请求体: {handled_payload}"
- )
- raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") from exception
+ logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
+ raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
- async def _transform_parameters(self, merged_params: dict) -> dict:
- """根据模型名称转换合并后的参数,并移除内部参数"""
- # (代码不变)
- new_params = dict(merged_params)
- new_params.pop("request_type", None)
+ async def _transform_parameters(self, params: dict) -> dict:
+ """
+ 根据模型名称转换参数:
+ - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
+ 并将 'max_tokens' 重命名为 'max_completion_tokens'
+ """
+ # 复制一份参数,避免直接修改原始数据
+ new_params = dict(params)
- if not self.is_gemini and self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+ if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+ # 删除 'temperature' 参数(如果存在)
new_params.pop("temperature", None)
+ # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
- elif self.is_gemini:
- gen_config = new_params.get("generationConfig", {})
- if "temperature" in new_params:
- gen_config["temperature"] = new_params.pop("temperature")
- if "max_tokens" in new_params:
- gen_config["maxOutputTokens"] = new_params.pop("max_tokens")
- if "top_p" in new_params:
- gen_config["topP"] = new_params.pop("top_p")
- if "top_k" in new_params:
- gen_config["topK"] = new_params.pop("top_k")
-
- if gen_config:
- new_params["generationConfig"] = gen_config
-
- new_params.pop("frequency_penalty", None)
- new_params.pop("presence_penalty", None)
- new_params.pop("max_completion_tokens", None)
-
return new_params
- async def _build_payload(
- self, prompt: str, image_base64: str = None, image_format: str = None, merged_params: dict = None
- ) -> dict:
- """构建请求体 (区分 Gemini 和 OpenAI),使用合并和转换后的参数"""
- # (代码不变)
- if merged_params is None:
- merged_params = self.params
-
- params_copy = await self._transform_parameters(merged_params)
-
- if self.is_gemini:
- parts = []
- if prompt:
- parts.append({"text": prompt})
- if image_base64:
- mime_type = f"image/{image_format.lower() if image_format else 'jpeg'}"
- parts.append({"inlineData": {"mimeType": mime_type, "data": image_base64}})
- payload = {"contents": [{"parts": parts}], **params_copy}
- payload.pop("model", None)
- # --- 添加 Gemini 安全设置 ---
- safety_settings = [
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
- {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
+ async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
+ """构建请求体"""
+ # 复制一份参数,避免直接修改 self.params
+ params_copy = await self._transform_parameters(self.params)
+ if image_base64:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
+ },
+ ],
+ }
]
- payload["safetySettings"] = safety_settings
- logger.debug(f"模型 {self.model_name}: 已为 Gemini 函数调用请求添加 safetySettings (BLOCK_NONE)。")
- # --- 结束添加安全设置 ---
-
else:
- if image_base64:
- messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": prompt},
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64}"
- },
- },
- ],
- }
- ]
- else:
- messages = [{"role": "user", "content": prompt}]
-
- payload = {
- "model": self.model_name,
- "messages": messages,
- **params_copy,
- }
- if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- if "max_tokens" not in params_copy and "max_completion_tokens" not in params_copy:
- payload["max_tokens"] = global_config.model_max_output_length
- if "max_completion_tokens" in payload:
- payload["max_tokens"] = payload.pop("max_completion_tokens")
-
+ messages = [{"role": "user", "content": prompt}]
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ **params_copy,
+ }
+ if "max_tokens" not in payload and "max_completion_tokens" not in payload:
+ payload["max_tokens"] = global_config.model_max_output_length
+ # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
+ if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
+ payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
def _default_response_handler(
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
) -> Tuple:
- """默认响应解析 (区分 Gemini 和 OpenAI),并处理函数/工具调用"""
- content = "没有返回结果"
- reasoning_content = ""
- tool_calls = None # OpenAI 格式
- function_call = None # Gemini 格式
- prompt_tokens = 0
- completion_tokens = 0
- total_tokens = 0
+ """默认响应解析"""
+ if "choices" in result and result["choices"]:
+ message = result["choices"][0]["message"]
+ content = message.get("content", "")
+ content, reasoning = self._extract_reasoning(content)
+ reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
+ if not reasoning_content:
+ reasoning_content = message.get("reasoning_content", "")
+ if not reasoning_content:
+ reasoning_content = reasoning
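+ # Fallback order for reasoning: model_extra.reasoning_content, then message.reasoning_content,
+ # then whatever _extract_reasoning pulled out of the <think> block.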
- if self.is_gemini:
- # --- 解析 Gemini 响应 ---
- try:
- if "candidates" in result and result["candidates"]:
- candidate = result["candidates"][0]
- # 检查是否有 content 和 parts
- if "content" in candidate and "parts" in candidate["content"] and candidate["content"]["parts"]:
- # 查找 functionCall 或 text 部分
- final_text_parts = []
- for part in candidate["content"]["parts"]:
- if "functionCall" in part:
- function_call = part["functionCall"] # 获取 Gemini 的 functionCall
- # Gemini functionCall 通常不与 text 一起返回,这里假设只处理 functionCall
- break # 找到 functionCall 就停止处理 parts
- elif "text" in part:
- final_text_parts.append(part.get("text", ""))
+ # 提取工具调用信息
+ tool_calls = message.get("tool_calls", None)
- if not function_call: # 如果没有 functionCall,处理 text
- raw_content = "".join(final_text_parts).strip()
- content, reasoning = self._extract_reasoning(raw_content)
- reasoning_content = reasoning
- # else: function_call 已获取,content 留空或设为特定值
+ # 记录token使用情况
+ usage = result.get("usage", {})
+ if usage:
+ prompt_tokens = usage.get("prompt_tokens", 0)
+ completion_tokens = usage.get("completion_tokens", 0)
+ total_tokens = usage.get("total_tokens", 0)
+ self._record_usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ user_id=user_id,
+ request_type=request_type if request_type is not None else self.request_type,
+ endpoint=endpoint,
+ )
- else:
- content = "Gemini响应中缺少 content 或 parts"
- logger.warning(f"模型 {self.model_name}: Gemini 响应格式不完整 (缺少 content/parts): {result}")
-
- finish_reason = candidate.get("finishReason")
- if finish_reason == "SAFETY":
- logger.warning(f"模型 {self.model_name}: Gemini 响应因安全设置被阻止。")
- content = "响应内容因安全原因被过滤。"
- elif finish_reason == "RECITATION":
- logger.warning(f"模型 {self.model_name}: Gemini 响应因引用限制被阻止。")
- content = "响应内容因引用限制被过滤。"
- elif finish_reason == "OTHER":
- logger.warning(f"模型 {self.model_name}: Gemini 响应因未知原因停止。")
- # finishReason == "TOOL_CODE" or "FUNCTION_CALL" 是正常情况
-
- usage = result.get("usageMetadata", {})
- if usage:
- prompt_tokens = usage.get("promptTokenCount", 0)
- completion_tokens = usage.get("candidatesTokenCount", 0)
- total_tokens = usage.get("totalTokenCount", 0)
- if completion_tokens == 0 and total_tokens > 0:
- completion_tokens = total_tokens - prompt_tokens
- else:
- logger.warning(f"模型 {self.model_name} (Gemini) 的响应中缺少 'usageMetadata' 信息。")
-
- except Exception as e:
- logger.error(f"解析 Gemini 响应出错: {e} - 响应: {result}")
- content = "解析 Gemini 响应时出错"
-
- else:
- # --- 解析 OpenAI 兼容响应 ---
- # (代码不变)
- if "choices" in result and result["choices"]:
- message = result["choices"][0].get("message", {})
- raw_content = message.get("content", "")
- content, reasoning = self._extract_reasoning(raw_content if raw_content else "")
-
- explicit_reasoning = message.get("model_extra", {}).get("reasoning_content", "")
- if not explicit_reasoning:
- explicit_reasoning = message.get("reasoning_content", "")
- reasoning_content = explicit_reasoning if explicit_reasoning else reasoning
-
- tool_calls = message.get("tool_calls", None) # 获取 OpenAI 的 tool_calls
-
- usage = result.get("usage", {})
- if usage:
- prompt_tokens = usage.get("prompt_tokens", 0)
- completion_tokens = usage.get("completion_tokens", 0)
- total_tokens = usage.get("total_tokens", 0)
- else:
- logger.warning(f"模型 {self.model_name} (OpenAI) 的响应中缺少 'usage' 信息。")
+ # 只有当tool_calls存在且不为空时才返回
+ if tool_calls:
+ logger.debug(f"检测到工具调用: {tool_calls}")
+ return content, reasoning_content, tool_calls
else:
- logger.warning(f"模型 {self.model_name} (OpenAI) 的响应格式不符合预期: {result}")
+ return content, reasoning_content
- # --- 记录 Token 使用情况 ---
- # (代码不变)
- if prompt_tokens > 0 or completion_tokens > 0 or total_tokens > 0:
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id=user_id,
- request_type=request_type,
- endpoint=endpoint,
- )
- else:
- logger.warning(f"模型 {self.model_name}: 未能从响应中提取有效的 token 使用信息。")
-
- # --- 返回结果 (统一格式) ---
- final_tool_calls = None
- if tool_calls: # 来自 OpenAI
- final_tool_calls = tool_calls
- logger.debug(f"检测到 OpenAI 工具调用: {final_tool_calls}")
- elif function_call: # 来自 Gemini
- logger.debug(f"检测到 Gemini 函数调用: {function_call}")
- # 将 Gemini functionCall 转换为 OpenAI tool_calls 格式
- # 注意: Gemini 的 functionCall 没有显式的 id 和 type,需要模拟
- final_tool_calls = [
- {
- "id": f"call_{random.randint(1000, 9999)}", # 生成一个随机 ID
- "type": "function",
- "function": {
- "name": function_call.get("name"),
- # Gemini 的参数在 'args' 中,OpenAI 在 'arguments' (通常是 JSON 字符串)
- # 需要将 Gemini 的 dict 参数转换为 JSON 字符串
- "arguments": json.dumps(function_call.get("args", {})),
- },
- }
- ]
- logger.debug(f"转换为 OpenAI tool_calls 格式: {final_tool_calls}")
-
- if final_tool_calls:
- # 如果有工具/函数调用,通常 content 为空或包含思考过程,这里返回转换后的调用信息
- return content, reasoning_content, final_tool_calls
- else:
- # 没有工具/函数调用,返回普通文本响应
- return content, reasoning_content
+ return "没有返回结果", ""
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
- # (代码不变)
- if not content:
- return "", ""
- match = re.search(r"(.*?) ", content, re.DOTALL)
- cleaned_content = re.sub(r".*? ", "", content, flags=re.DOTALL, count=1).strip()
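+ # Reasoning models in the DeepSeek-R1 style wrap chain-of-thought in <think>...</think>;
+ # the opening tag is made optional below since some providers appear to omit it in output.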
+ match = re.search(r"(?:)?(.*?) ", content, re.DOTALL)
+ content = re.sub(r"(?:)?.*? ", "", content, flags=re.DOTALL, count=1).strip()
if match:
reasoning = match.group(1).strip()
else:
reasoning = ""
- return cleaned_content, reasoning
+ return content, reasoning
- async def _build_headers(self, api_key: str, no_key: bool = False) -> dict:
- """构建请求头 (区分 Gemini 和 OpenAI)"""
- # (代码不变)
+ async def _build_headers(self, no_key: bool = False) -> dict:
+ """构建请求头"""
if no_key:
- if self.is_gemini:
- return {"x-goog-api-key": "**********", "Content-Type": "application/json"}
- else:
- return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
+ return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
else:
- if not api_key:
- logger.error(f"尝试使用无效 (空) 的 API key 为模型 {self.model_name} 构建请求头。")
- raise ValueError("无效的 API key 提供给 _build_headers。")
+ return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+ # 防止小朋友们截图自己的key
- if self.is_gemini:
- return {"x-goog-api-key": api_key, "Content-Type": "application/json"}
- else:
- return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+ async def generate_response(self, prompt: str) -> Tuple:
+ """根据输入的提示生成模型的异步响应"""
- async def generate_response(self, prompt: str, user_id: str = "system", **kwargs) -> Tuple:
- """根据输入的提示生成模型的异步响应,支持覆盖参数"""
- endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
- response = await self._execute_request(
- endpoint=endpoint, prompt=prompt, user_id=user_id, request_type="chat", **kwargs
- )
+ response = await self._execute_request(endpoint="/chat/completions", prompt=prompt)
+ # 根据返回值的长度决定怎么处理
if len(response) == 3:
content, reasoning_content, tool_calls = response
return content, reasoning_content, self.model_name, tool_calls
@@ -1225,292 +714,176 @@ class LLMRequest:
content, reasoning_content = response
return content, reasoning_content, self.model_name
- async def generate_response_for_image(
- self, prompt: str, image_base64: str, image_format: str, user_id: str = "system", **kwargs
- ) -> Tuple:
- """根据输入的提示和图片生成模型的异步响应,支持覆盖参数"""
- endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
- response = await self._execute_request(
- endpoint=endpoint,
- prompt=prompt,
- image_base64=image_base64,
- image_format=image_format,
- user_id=user_id,
- request_type="vision",
- **kwargs,
- )
- # _default_response_handler 现在总是返回至少2个值
- if len(response) == 3:
- return response # content, reasoning, tool_calls (tool_calls 可能为 None)
- elif len(response) == 2:
- content, reasoning = response
- return content, reasoning # 对于 vision 请求,通常没有 tool_calls
- else:
- logger.error(f"来自 _default_response_handler 的意外响应格式: {response}")
- return "处理响应出错", ""
+ async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
+ """根据输入的提示和图片生成模型的异步响应"""
- async def generate_response_async(
- self, prompt: str, user_id: str = "system", request_type: str = "chat", **kwargs
- ) -> Union[str, Tuple]:
- """异步方式根据输入的提示生成模型的响应 (通用),支持覆盖参数"""
- # (代码不变)
- endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
response = await self._execute_request(
- endpoint=endpoint,
- prompt=prompt,
- payload=None,
- retry_policy=None,
- response_handler=None,
- user_id=user_id,
- request_type=request_type,
- **kwargs,
+ endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
)
+ # 根据返回值的长度决定怎么处理
+ if len(response) == 3:
+ content, reasoning_content, tool_calls = response
+ return content, reasoning_content, tool_calls
+ else:
+ content, reasoning_content = response
+ return content, reasoning_content
+
+ async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
+ """异步方式根据输入的提示生成模型的响应"""
+ # 构建请求体,不硬编码max_tokens
+ data = {
+ "model": self.model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ **self.params,
+ **kwargs,
+ }
+
+ response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
+ # 原样返回响应,不做处理
+
return response
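
The payload merge above gives per-call kwargs precedence over self.params, because later entries in a dict literal overwrite earlier ones; a tiny demonstration with stand-in values:

defaults = {"temperature": 0.7, "max_tokens": 256}   # stand-in for self.params
overrides = {"temperature": 0.2}                     # stand-in for **kwargs
payload = {
    "model": "example-model",                        # stand-in for self.model_name
    "messages": [{"role": "user", "content": "hi"}],
    **defaults,
    **overrides,
}
assert payload["temperature"] == 0.2 and payload["max_tokens"] == 256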
- # 修改:实现 Gemini Function Calling 的 Payload 构建
- async def generate_response_tool_async(
- self, prompt: str, tools: list, user_id: str = "system", **kwargs
- ) -> tuple[str, str, list | None]:
- """异步方式根据输入的提示和工具生成模型的响应,支持覆盖参数和 Gemini 函数调用"""
-
- endpoint = ":generateContent" if self.is_gemini else "/chat/completions"
- merged_params = {**self.params, **kwargs}
- transformed_params = await self._transform_parameters(merged_params) # 清理 request_type 等
-
- payload = None
-
- if self.is_gemini:
- # --- 构建 Gemini Function Calling Payload ---
- logger.debug(f"为 Gemini ({self.model_name}) 构建函数调用请求。")
- # 1. 转换工具定义 (OpenAI -> Gemini)
- # OpenAI tool format: [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
- # Gemini tool format: [{"functionDeclarations": [{"name": ..., "description": ..., "parameters": ...}]}]
- function_declarations = []
- if tools:
- for tool in tools:
- if tool.get("type") == "function" and "function" in tool:
- func_def = tool["function"]
- # Gemini parameters 使用 OpenAPI Schema,与 OpenAI 基本兼容
- function_declarations.append(
- {
- "name": func_def.get("name"),
- "description": func_def.get("description", ""), # Description is required for Gemini
- "parameters": func_def.get(
- "parameters", {"type": "object", "properties": {}}
- ), # Ensure parameters exist
- }
- )
- else:
- logger.warning(f"跳过不支持的工具类型或格式: {tool}")
-
- if not function_declarations:
- logger.error("没有有效的函数声明可用于 Gemini 请求。")
- return "没有提供有效的函数定义", "", None
-
- gemini_tools = [{"functionDeclarations": function_declarations}]
-
- # 2. 构建 Gemini Payload
- # parts = [{"text": prompt}] # 初始 parts
- payload = {
- "contents": [{"parts": [{"text": prompt}]}], # 包含用户提示
- "tools": gemini_tools,
- # toolConfig 默认是 AUTO,可以根据需要从 kwargs 获取或硬编码
- # "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # 例如强制调用
- **transformed_params, # 合并其他转换后的参数 (如 generationConfig)
- }
- payload.pop("model", None) # Gemini 不在顶层传 model
- payload.pop("messages", None) # 移除 OpenAI 特有的 messages
- payload.pop("tool_choice", None) # 移除 OpenAI 特有的 tool_choice
-
- logger.trace(f"构建的 Gemini 函数调用 Payload: {json.dumps(payload, indent=2)}")
+    async def generate_response_tool_async(self, prompt: str, tools: list, **kwargs) -> tuple[str, str, list | None]:
+        """异步方式根据输入的提示和工具生成模型的响应"""
+ # 构建请求体,不硬编码max_tokens
+ data = {
+ "model": self.model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ **self.params,
+ **kwargs,
+ "tools": tools,
+ }
+ response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
+ logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具,返回结果: {response}")
+ # 检查响应是否包含工具调用
+ if len(response) == 3:
+ content, reasoning_content, tool_calls = response
+ logger.debug(f"收到工具调用响应,包含 {len(tool_calls) if tool_calls else 0} 个工具调用")
+ return content, reasoning_content, tool_calls
else:
- # --- 构建 OpenAI Tool Calling Payload ---
- # (逻辑不变)
- logger.debug(f"为 OpenAI 兼容模型 ({self.model_name}) 构建工具调用请求。")
- payload = {
- "model": self.model_name,
- "messages": [{"role": "user", "content": prompt}],
- **transformed_params,
- "tools": tools,
- "tool_choice": transformed_params.get("tool_choice", "auto"),
- }
- if "max_completion_tokens" in payload:
- payload["max_tokens"] = payload.pop("max_completion_tokens")
- if "max_tokens" not in payload:
- payload["max_tokens"] = global_config.model_max_output_length
-
- # --- 执行请求 ---
- if payload is None:
- logger.error("未能构建有效的 API 请求 payload。")
- return "内部错误:无法构建请求", "", None
-
- response = await self._execute_request(
- endpoint=endpoint,
- payload=payload,
- prompt=prompt, # prompt 仍然需要,用于可能的重试
- user_id=user_id,
- request_type="tool_call",
- **kwargs, # 传递原始 kwargs 以便在重试时重新合并
- )
-
- # _default_response_handler 现在会处理 Gemini functionCall 并统一格式
- logger.debug(f"模型 {self.model_name} 工具/函数调用返回结果: {response}")
-
- if isinstance(response, tuple) and len(response) == 3:
- content, reasoning_content, final_tool_calls = response
- # final_tool_calls 已经是统一的 OpenAI 格式
- return content, reasoning_content, final_tool_calls
- elif isinstance(response, tuple) and len(response) == 2:
content, reasoning_content = response
- logger.debug("收到普通响应,无工具/函数调用")
+ logger.debug("收到普通响应,无工具调用")
return content, reasoning_content, None
- else:
- logger.error(f"收到来自 _execute_request/_default_response_handler 的意外响应格式: {response}")
- return "处理响应时出错", "", None
- async def get_embedding(self, text: str, user_id: str = "system", **kwargs) -> Union[list, None]:
- """异步方法:获取文本的embedding向量,支持覆盖参数 (Gemini Embedding 需注意模型名称)"""
- # (代码不变)
+ async def get_embedding(self, text: str) -> Union[list, None]:
+ """异步方法:获取文本的embedding向量
+
+ Args:
+ text: 需要获取embedding的文本
+
+ Returns:
+ list: embedding向量,如果失败则返回None
+ """
+
if len(text) < 1:
logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
return None
- api_kwargs = {k: v for k, v in kwargs.items() if k != "request_type"}
-
- if self.is_gemini:
- endpoint = ":embedContent"
- payload = {"model": f"models/{self.model_name}", "content": {"parts": [{"text": text}]}, **api_kwargs}
- payload.pop("encoding_format", None)
- payload.pop("input", None)
-
- else:
- endpoint = "/embeddings"
- payload = {"model": self.model_name, "input": text, "encoding_format": "float", **api_kwargs}
- payload.pop("content", None)
- payload.pop("taskType", None)
-
def embedding_handler(result):
- # (代码不变)
- embedding_value = None
- prompt_tokens = 0
- completion_tokens = 0
- total_tokens = 0
-
- if self.is_gemini:
- if "embedding" in result and "value" in result["embedding"]:
- embedding_value = result["embedding"]["value"]
- logger.warning(f"模型 {self.model_name} (Gemini Embedding): 响应中未找到明确的 token 使用信息。")
- else:
- if "data" in result and len(result["data"]) > 0:
- embedding_value = result["data"][0].get("embedding", None)
+ """处理响应"""
+ if "data" in result and len(result["data"]) > 0:
+ # 提取 token 使用信息
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
+ completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", 0)
- else:
- logger.warning(f"模型 {self.model_name} (OpenAI Embedding) 的响应中缺少 'usage' 信息。")
-
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id=user_id,
- request_type="embedding",
- endpoint=endpoint,
- )
- return embedding_value
+                    # 记录 token 使用情况
+                    self._record_usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                        user_id="system",  # 可以根据需要修改 user_id
+                        request_type=self.request_type,  # 使用实例自身的请求类型
+                        endpoint="/embeddings",  # API 端点
+                    )
+                return result["data"][0].get("embedding", None)
+            return None
embedding = await self._execute_request(
- endpoint=endpoint,
- payload=payload,
+ endpoint="/embeddings",
prompt=text,
+ payload={"model": self.model_name, "input": text, "encoding_format": "float"},
retry_policy={"max_retries": 2, "base_wait": 6},
response_handler=embedding_handler,
- user_id=user_id,
- request_type="embedding",
- **api_kwargs,
)
return embedding
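
A minimal sketch of the response shape the embedding handler above expects (OpenAI-style {"data": [...], "usage": {...}}), with hypothetical values:

def parse_embedding(result: dict):
    # Mirrors the happy path of embedding_handler, minus usage recording.
    if "data" in result and len(result["data"]) > 0:
        return result["data"][0].get("embedding", None)
    return None

sample = {"data": [{"embedding": [0.1, 0.2]}], "usage": {"prompt_tokens": 3, "total_tokens": 3}}
assert parse_embedding(sample) == [0.1, 0.2]
assert parse_embedding({}) is None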
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
- """压缩base64格式的图片到指定大小"""
- # (代码不变)
+ """压缩base64格式的图片到指定大小
+ Args:
+ base64_data: base64编码的图片数据
+ target_size: 目标文件大小(字节),默认0.8MB
+ Returns:
+ str: 压缩后的base64图片数据
+ """
try:
+ # 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
- if len(image_data) <= target_size * 1.05:
- logger.info(f"图片大小 {len(image_data) / 1024:.1f}KB 已足够小,无需压缩。")
- return base64_data
- img = Image.open(io.BytesIO(image_data))
- img_format = img.format
- original_width, original_height = img.size
- scale = max(0.2, min(1.0, (target_size / len(image_data)) ** 0.5))
- new_width = max(1, int(original_width * scale))
- new_height = max(1, int(original_height * scale))
- output_buffer = io.BytesIO()
- save_format = img_format # Default to original format
- if getattr(img, "is_animated", False) and img.n_frames > 1:
+        # 如果已经小于2MB,直接返回原图
+        if len(image_data) <= 2 * 1024 * 1024:
+            return base64_data
+
+ # 将字节数据转换为图片对象
+ img = Image.open(io.BytesIO(image_data))
+
+ # 获取原始尺寸
+ original_width, original_height = img.size
+
+ # 计算缩放比例
+ scale = min(1.0, (target_size / len(image_data)) ** 0.5)
+
+ # 计算新的尺寸
+ new_width = int(original_width * scale)
+ new_height = int(original_height * scale)
+
+ # 创建内存缓冲区
+ output_buffer = io.BytesIO()
+
+ # 如果是GIF,处理所有帧
+ if getattr(img, "is_animated", False):
frames = []
- durations = []
- loop = img.info.get("loop", 0)
- disposal = img.info.get("disposal", 2)
- logger.info(f"检测到 GIF 动图 ({img.n_frames} 帧),尝试按比例压缩...")
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
- current_duration = img.info.get("duration", 100)
- durations.append(current_duration)
- new_frame = img.convert("RGBA").copy()
- resized_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
- frames.append(resized_frame)
- if frames:
- frames[0].save(
- output_buffer,
- format="GIF",
- save_all=True,
- append_images=frames[1:],
- optimize=False,
- duration=durations,
- loop=loop,
- disposal=disposal,
- transparency=img.info.get("transparency", None),
- background=img.info.get("background", None),
- )
- save_format = "GIF"
- else:
- logger.warning("未能处理 GIF 帧。")
- return base64_data
- else:
- if img.mode in ("RGBA", "LA") or "transparency" in img.info:
- resized_img = img.convert("RGBA").resize((new_width, new_height), Image.Resampling.LANCZOS)
- save_format = "PNG"
- save_params = {"optimize": True}
- else:
- resized_img = img.convert("RGB").resize((new_width, new_height), Image.Resampling.LANCZOS)
- if img_format and img_format.upper() == "JPEG":
- save_format = "JPEG"
- save_params = {"quality": 85, "optimize": True}
- else:
- save_format = "PNG"
- save_params = {"optimize": True}
- resized_img.save(output_buffer, format=save_format, **save_params)
+ new_frame = img.copy()
+ new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
+ frames.append(new_frame)
- compressed_data = output_buffer.getvalue()
- logger.success(
- f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height} ({img.format} -> {save_format})"
- )
- logger.info(
- f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB (目标: {target_size / 1024:.1f}KB)"
- )
- if len(compressed_data) < len(image_data) * 0.95:
- return base64.b64encode(compressed_data).decode("utf-8")
+ # 保存到缓冲区
+ frames[0].save(
+ output_buffer,
+ format="GIF",
+ save_all=True,
+ append_images=frames[1:],
+ optimize=True,
+ duration=img.info.get("duration", 100),
+ loop=img.info.get("loop", 0),
+ )
else:
- logger.info("压缩效果不明显或反而增大,返回原始图片。")
- return base64_data
+ # 处理静态图片
+ resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # 保存到缓冲区,保持原始格式
+ if img.format == "PNG" and img.mode in ("RGBA", "LA"):
+ resized_img.save(output_buffer, format="PNG", optimize=True)
+ else:
+ resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
+
+ # 获取压缩后的数据并转换为base64
+ compressed_data = output_buffer.getvalue()
+ logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
+ logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
+
+ return base64.b64encode(compressed_data).decode("utf-8")
+
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
- return base64_data
+ return base64_data
\ No newline at end of file
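
A worked example of the scale formula used above: shrinking both sides by sqrt(target/actual) shrinks the pixel area, and hence roughly the file size, by target/actual:

actual_size = 3 * 1024 * 1024        # hypothetical 3MB input
target_size = 0.8 * 1024 * 1024      # 0.8MB target, as in the default
scale = min(1.0, (target_size / actual_size) ** 0.5)
print(f"scale = {scale:.3f}")        # ~0.516, so 1000x800 -> 516x413
print(516 * 413 / (1000 * 800))      # ~0.266, close to target/actual = 0.267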
diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py
index 41c18593..bc65776e 100644
--- a/src/chat/person_info/person_info.py
+++ b/src/chat/person_info/person_info.py
@@ -205,7 +205,7 @@ class PersonInfoManager:
existing_names = ""
while current_try < max_retries:
individuality = Individuality.get_instance()
- prompt_personality = individuality.get_prompt(x_person=2, level=3)
+ prompt_personality = individuality.get_prompt(x_person=2, level=1)
bot_name = individuality.personality.bot_nickname
qv_name_prompt = f"你是{bot_name},{prompt_personality}"
diff --git a/src/chat/person_info/relationship_manager.py b/src/chat/person_info/relationship_manager.py
index e1e611cc..3b873f50 100644
--- a/src/chat/person_info/relationship_manager.py
+++ b/src/chat/person_info/relationship_manager.py
@@ -313,7 +313,7 @@ class RelationshipManager:
value = self.mood_feedback(value)
level_num = self.calculate_level_num(old_value + value)
- relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "依赖"]
+ relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
logger.info(
f"用户: {user_info.user_nickname}"
f"当前关系: {relationship_level[level_num]}, "
@@ -400,7 +400,7 @@ class RelationshipManager:
value = self.mood_feedback(value)
level_num = self.calculate_level_num(old_value + value)
- relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "依赖"]
+ relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
logger.info(
f"用户: {chat_stream.user_info.user_nickname}"
f"当前关系: {relationship_level[level_num]}, "
@@ -425,7 +425,7 @@ class RelationshipManager:
level_num = self.calculate_level_num(relationship_value)
if level_num == 0 or level_num == 5:
- relationship_level = ["厌恶", "冷漠以对", "认识", "友好对待", "喜欢", "依赖"]
+ relationship_level = ["厌恶", "冷漠以对", "认识", "友好对待", "喜欢", "暧昧"]
relation_prompt2_list = [
"忽视的回应",
"冷淡回复",
@@ -439,7 +439,7 @@ class RelationshipManager:
return ""
else:
if random.random() < 0.6:
- relationship_level = ["厌恶", "冷漠以对", "认识", "友好对待", "喜欢", "依赖"]
+ relationship_level = ["厌恶", "冷漠以对", "认识", "友好对待", "喜欢", "暧昧"]
relation_prompt2_list = [
"忽视的回应",
"冷淡回复",
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index 1b9d1f14..5fa76deb 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -69,19 +69,9 @@ class OnlineTimeRecordTask(AsyncTask):
else:
# 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now()
- recent_record = db.online_time.find_one(
+ if recent_record := db.online_time.find_one(
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
- )
-
- if not recent_record:
- # 若没有记录,则插入新的在线时间记录
- self.record_id = db.online_time.insert_one(
- {
- "start_timestamp": current_time,
- "end_timestamp": current_time + timedelta(minutes=1),
- }
- ).inserted_id
- else:
+ ):
# 如果有记录,则更新结束时间
self.record_id = recent_record["_id"]
db.online_time.update_one(
@@ -92,8 +82,16 @@ class OnlineTimeRecordTask(AsyncTask):
}
},
)
- except Exception:
- logger.exception("在线时间记录失败")
+ else:
+ # 若没有记录,则插入新的在线时间记录
+ self.record_id = db.online_time.insert_one(
+ {
+ "start_timestamp": current_time,
+ "end_timestamp": current_time + timedelta(minutes=1),
+ }
+ ).inserted_id
+ except Exception as e:
+ logger.error(f"在线时间记录失败,错误信息:{e}")
def _format_online_time(online_seconds: int) -> str:
@@ -102,7 +100,7 @@ def _format_online_time(online_seconds: int) -> str:
:param online_seconds: 在线时间(秒)
:return: 格式化后的在线时间字符串
"""
- total_oneline_time = timedelta(seconds=int(online_seconds)) # 确保是整数
+ total_oneline_time = timedelta(seconds=online_seconds)
days = total_oneline_time.days
hours = total_oneline_time.seconds // 3600
@@ -110,15 +108,13 @@ def _format_online_time(online_seconds: int) -> str:
seconds = total_oneline_time.seconds % 60
if days > 0:
# 如果在线时间超过1天,则格式化为"X天X小时X分钟"
- total_oneline_time_str = f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
+ return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
elif hours > 0:
# 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
- total_oneline_time_str = f"{hours}小时{minutes}分钟{seconds}秒"
+ return f"{hours}小时{minutes}分钟{seconds}秒"
else:
# 其他情况格式化为"X分钟X秒"
- total_oneline_time_str = f"{minutes}分钟{seconds}秒"
-
- return total_oneline_time_str
+ return f"{minutes}分钟{seconds}秒"
class StatisticOutputTask(AsyncTask):
@@ -141,7 +137,7 @@ class StatisticOutputTask(AsyncTask):
记录文件路径
"""
- now = datetime.now() # Renamed to avoid conflict with 'now' in methods
+ now = datetime.now()
if "deploy_time" in local_storage:
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
@@ -167,17 +163,16 @@ class StatisticOutputTask(AsyncTask):
:param now: 基准当前时间
"""
# 输出最近一小时的统计数据
- last_hour_stats = stats.get("last_hour", {}) # Ensure 'last_hour' key exists
output = [
self.SEP_LINE,
f" 最近1小时的统计数据 (自{now.strftime('%Y-%m-%d %H:%M:%S')}开始,详细信息见文件:{self.record_file_path})",
self.SEP_LINE,
- self._format_total_stat(last_hour_stats),
+ self._format_total_stat(stats["last_hour"]),
"",
- self._format_model_classified_stat(last_hour_stats),
+ self._format_model_classified_stat(stats["last_hour"]),
"",
- self._format_chat_stat(last_hour_stats),
+ self._format_chat_stat(stats["last_hour"]),
self.SEP_LINE,
"",
]
@@ -191,10 +186,7 @@ class StatisticOutputTask(AsyncTask):
stats = self._collect_all_statistics(now)
# 输出统计数据到控制台
- if "last_hour" in stats: # Check if stats for last_hour were successfully collected
- self._statistic_console_output(stats, now)
- else:
- logger.warning("无法输出最近一小时统计数据到控制台,因为数据缺失。")
+ self._statistic_console_output(stats, now)
# 输出统计数据到html文件
self._generate_html_report(stats, now)
except Exception as e:
@@ -207,29 +199,37 @@ class StatisticOutputTask(AsyncTask):
"""
收集指定时间段的LLM请求统计数据
- :param collect_period: 统计时间段 [(period_key, start_datetime), ...]
+ :param collect_period: 统计时间段
"""
- if not collect_period:
+ if len(collect_period) <= 0:
return {}
-
- collect_period.sort(key=lambda x: x[1], reverse=True)
+ else:
+ # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
+ # 总LLM请求数
TOTAL_REQ_CNT: 0,
+ # 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
+ # 输入Token数
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
+ # 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
+ # 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
+ # 总开销
TOTAL_COST: 0.0,
+ # 请求开销统计
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -237,54 +237,46 @@ class StatisticOutputTask(AsyncTask):
for period_key, _ in collect_period
}
- # Determine the overall earliest start time for the database query
- # This assumes collect_period is not empty, which is checked at the beginning.
- overall_earliest_start_time = min(p[1] for p in collect_period)
-
- for record in db.llm_usage.find({"timestamp": {"$gte": overall_earliest_start_time}}):
+ # 以最早的时间戳为起始时间获取记录
+ for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
record_timestamp = record.get("timestamp")
- if not isinstance(record_timestamp, datetime): # Ensure timestamp is a datetime object
- try: # Attempt conversion if it's a number (e.g. Unix timestamp)
- record_timestamp = datetime.fromtimestamp(float(record_timestamp))
- except (ValueError, TypeError):
- logger.warning(f"Skipping LLM usage record with invalid timestamp: {record.get('_id')}")
- continue
+ for idx, (_, period_start) in enumerate(collect_period):
+ if record_timestamp >= period_start:
+ # 如果记录时间在当前时间段内,则它一定在更早的时间段内
+ # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
+ for period_key, _ in collect_period[idx:]:
+ stats[period_key][TOTAL_REQ_CNT] += 1
- for idx, (_current_period_key, period_start_time) in enumerate(collect_period):
- if record_timestamp >= period_start_time:
- for period_key_to_update, _ in collect_period[idx:]:
- stats[period_key_to_update][TOTAL_REQ_CNT] += 1
+ request_type = record.get("request_type", "unknown") # 请求类型
+ user_id = str(record.get("user_id", "unknown")) # 用户ID
+ model_name = record.get("model_name", "unknown") # 模型名称
- request_type = record.get("request_type", "unknown")
- user_id = str(record.get("user_id", "unknown"))
- model_name = record.get("model_name", "unknown")
+ stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
+ stats[period_key][REQ_CNT_BY_USER][user_id] += 1
+ stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
- stats[period_key_to_update][REQ_CNT_BY_TYPE][request_type] += 1
- stats[period_key_to_update][REQ_CNT_BY_USER][user_id] += 1
- stats[period_key_to_update][REQ_CNT_BY_MODEL][model_name] += 1
+ prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
+ completion_tokens = record.get("completion_tokens", 0) # 输出Token数
+ total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
- prompt_tokens = record.get("prompt_tokens", 0)
- completion_tokens = record.get("completion_tokens", 0)
- total_tokens = prompt_tokens + completion_tokens
+ stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
+ stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
+ stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens
- stats[period_key_to_update][IN_TOK_BY_TYPE][request_type] += prompt_tokens
- stats[period_key_to_update][IN_TOK_BY_USER][user_id] += prompt_tokens
- stats[period_key_to_update][IN_TOK_BY_MODEL][model_name] += prompt_tokens
+ stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens
+ stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens
+ stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens
- stats[period_key_to_update][OUT_TOK_BY_TYPE][request_type] += completion_tokens
- stats[period_key_to_update][OUT_TOK_BY_USER][user_id] += completion_tokens
- stats[period_key_to_update][OUT_TOK_BY_MODEL][model_name] += completion_tokens
-
- stats[period_key_to_update][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
- stats[period_key_to_update][TOTAL_TOK_BY_USER][user_id] += total_tokens
- stats[period_key_to_update][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
+ stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens
+ stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
+ stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0)
- stats[period_key_to_update][TOTAL_COST] += cost
- stats[period_key_to_update][COST_BY_TYPE][request_type] += cost
- stats[period_key_to_update][COST_BY_USER][user_id] += cost
- stats[period_key_to_update][COST_BY_MODEL][model_name] += cost
- break
+ stats[period_key][TOTAL_COST] += cost
+ stats[period_key][COST_BY_TYPE][request_type] += cost
+ stats[period_key][COST_BY_USER][user_id] += cost
+ stats[period_key][COST_BY_MODEL][model_name] += cost
+ break # 取消更早时间段的判断
return stats
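
The early break works because collect_period is sorted by start time in descending order: a record that falls inside period idx necessarily falls inside every longer period after it, so the inner loop updates them all and stops. A self-contained sketch with hypothetical periods:

from datetime import datetime, timedelta

now = datetime.now()
periods = [("last_hour", now - timedelta(hours=1)),
           ("last_24_hours", now - timedelta(days=1)),
           ("last_7_days", now - timedelta(days=7))]
counts = {key: 0 for key, _ in periods}

record_time = now - timedelta(hours=5)  # inside 24h and 7d, outside 1h
for idx, (_, start) in enumerate(periods):
    if record_time >= start:
        for key, _ in periods[idx:]:    # current period and every longer one
            counts[key] += 1
        break                           # earlier (shorter) periods already ruled out

assert counts == {"last_hour": 0, "last_24_hours": 1, "last_7_days": 1}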
@@ -293,43 +285,40 @@ class StatisticOutputTask(AsyncTask):
"""
收集指定时间段的在线时间统计数据
- :param collect_period: 统计时间段 [(period_key, start_datetime), ...]
- :param now: 当前时间,用于校准end_timestamp
+ :param collect_period: 统计时间段
"""
- if not collect_period:
+ if len(collect_period) <= 0:
return {}
-
- collect_period.sort(key=lambda x: x[1], reverse=True)
+ else:
+ # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
+ # 在线时间统计
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
- overall_earliest_start_time = min(p[1] for p in collect_period)
-
- for record in db.online_time.find({"end_timestamp": {"$gte": overall_earliest_start_time}}):
- record_end_timestamp: datetime = record.get("end_timestamp")
- record_start_timestamp: datetime = record.get("start_timestamp")
-
- if not isinstance(record_end_timestamp, datetime) or not isinstance(record_start_timestamp, datetime):
- logger.warning(f"Skipping online_time record with invalid timestamps: {record.get('_id')}")
- continue
-
- actual_end_timestamp = min(record_end_timestamp, now)
-
- for idx, (_current_period_key, period_start_time) in enumerate(collect_period):
- if record_start_timestamp < now and actual_end_timestamp > period_start_time:
- overlap_start = max(record_start_timestamp, period_start_time)
- overlap_end = min(actual_end_timestamp, now)
-
- if overlap_end > overlap_start:
- duration_seconds = (overlap_end - overlap_start).total_seconds()
- for period_key_to_update, _ in collect_period[idx:]:
- stats[period_key_to_update][ONLINE_TIME] += duration_seconds
- break
+ # 统计在线时间
+ for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
+ end_timestamp: datetime = record.get("end_timestamp")
+ for idx, (_, period_start) in enumerate(collect_period):
+ if end_timestamp >= period_start:
+ # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间
+ end_timestamp = min(end_timestamp, now)
+ # 如果记录时间在当前时间段内,则它一定在更早的时间段内
+ # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
+ for period_key, _period_start in collect_period[idx:]:
+ start_timestamp: datetime = record.get("start_timestamp")
+                        if start_timestamp < _period_start:
+                            # 如果开始时间早于时间段起点,则从时间段起点开始计算
+                            stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
+                        else:
+                            # 否则,从记录的开始时间开始计算
+                            stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
+ break # 取消更早时间段的判断
return stats
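
The clipping logic above in miniature: a session that started before the period boundary only contributes the overlap with the period. Hypothetical timestamps:

from datetime import datetime

period_start = datetime(2024, 1, 1, 12, 0, 0)
session_start = datetime(2024, 1, 1, 11, 30, 0)  # began before the boundary
session_end = datetime(2024, 1, 1, 12, 45, 0)

# Clip the start to the period boundary, as the branch above does.
start = period_start if session_start < period_start else session_start
print((session_end - start).total_seconds())     # 2700.0, i.e. 45 minutes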
@@ -337,66 +326,55 @@ class StatisticOutputTask(AsyncTask):
"""
收集指定时间段的消息统计数据
- :param collect_period: 统计时间段 [(period_key, start_datetime), ...]
+ :param collect_period: 统计时间段
"""
- if not collect_period:
+ if len(collect_period) <= 0:
return {}
-
- collect_period.sort(key=lambda x: x[1], reverse=True)
+ else:
+ # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
+ collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
+ # 消息统计
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
- overall_earliest_start_timestamp_float = min(p[1].timestamp() for p in collect_period)
+ # 统计消息量
+ for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
+ chat_info = message.get("chat_info", None) # 聊天信息
+ user_info = message.get("user_info", None) # 用户信息(消息发送人)
+ message_time = message.get("time", 0) # 消息时间
- for message in db.messages.find({"time": {"$gte": overall_earliest_start_timestamp_float}}):
- chat_info = message.get("chat_info", {})
- user_info = message.get("user_info", {})
- message_time_ts = message.get("time")
-
- if message_time_ts is None:
- logger.warning(f"Skipping message record with no timestamp: {message.get('_id')}")
- continue
-
- try:
- message_datetime = datetime.fromtimestamp(float(message_time_ts))
- except (ValueError, TypeError):
- logger.warning(f"Skipping message record with invalid time format: {message.get('_id')}")
- continue
-
- group_info = chat_info.get("group_info")
- chat_id = None
- chat_name = None
-
- if group_info and group_info.get("group_id"):
- gid = group_info.get("group_id")
- chat_id = f"g{gid}"
- chat_name = group_info.get("group_name", f"群聊 {gid}")
- elif user_info and user_info.get("user_id"):
- uid = user_info["user_id"]
- chat_id = f"u{uid}"
- chat_name = user_info.get("user_nickname", f"用户 {uid}")
-
- if not chat_id:
- continue
-
- current_mapping = self.name_mapping.get(chat_id)
- if current_mapping:
- if chat_name != current_mapping[0] and message_time_ts > current_mapping[1]:
- self.name_mapping[chat_id] = (chat_name, message_time_ts)
+ group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
+ if group_info is not None:
+ # 若有群聊信息
+ chat_id = f"g{group_info.get('group_id')}"
+ chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
+ elif user_info:
+ # 若没有群聊信息,则尝试获取用户信息
+ chat_id = f"u{user_info['user_id']}"
+ chat_name = user_info["user_nickname"]
else:
- self.name_mapping[chat_id] = (chat_name, message_time_ts)
+ continue # 如果没有群组信息也没有用户信息,则跳过
- for idx, (_current_period_key, period_start_time) in enumerate(collect_period):
- if message_datetime >= period_start_time:
- for period_key_to_update, _ in collect_period[idx:]:
- stats[period_key_to_update][TOTAL_MSG_CNT] += 1
- stats[period_key_to_update][MSG_CNT_BY_CHAT][chat_id] += 1
+ if chat_id in self.name_mapping:
+ if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
+ # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
+ self.name_mapping[chat_id] = (chat_name, message_time)
+ else:
+ self.name_mapping[chat_id] = (chat_name, message_time)
+
+ for idx, (_, period_start) in enumerate(collect_period):
+ if message_time >= period_start.timestamp():
+ # 如果记录时间在当前时间段内,则它一定在更早的时间段内
+ # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
+ for period_key, _ in collect_period[idx:]:
+ stats[period_key][TOTAL_MSG_CNT] += 1
+ stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
return stats
@@ -406,77 +384,53 @@ class StatisticOutputTask(AsyncTask):
收集各时间段的统计数据
:param now: 基准当前时间
"""
- # Correctly determine deploy_time
- if "deploy_time" in local_storage:
- try:
- deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
- except (TypeError, ValueError):
- logger.error("Invalid deploy_time in local_storage. Resetting.")
- deploy_time = datetime(2000, 1, 1)
- local_storage["deploy_time"] = now.timestamp()
- else:
- deploy_time = datetime(2000, 1, 1)
- local_storage["deploy_time"] = now.timestamp()
- # Rebuild stat_period based on the current 'now' and determined 'deploy_time'
- current_stat_periods_config = [
- ("all_time", now - deploy_time if now > deploy_time else timedelta(seconds=0), "自部署以来"),
- ("last_7_days", timedelta(days=7), "最近7天"),
- ("last_24_hours", timedelta(days=1), "最近24小时"),
- ("last_hour", timedelta(hours=1), "最近1小时"),
- ]
- self.stat_period = current_stat_periods_config # Update instance's stat_period if needed elsewhere
+ last_all_time_stat = None
- stat_start_timestamp_config = []
- for period_name, delta, _ in current_stat_periods_config:
- start_dt = deploy_time if period_name == "all_time" else now - delta
- stat_start_timestamp_config.append((period_name, start_dt))
+ if "last_full_statistics" in local_storage:
+ # 如果存在上次完整统计数据,则使用该数据进行增量统计
+ last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
- # 收集各类数据
- model_req_stat = self._collect_model_request_for_period(stat_start_timestamp_config)
- online_time_stat = self._collect_online_time_for_period(stat_start_timestamp_config, now)
- message_count_stat = self._collect_message_count_for_period(stat_start_timestamp_config)
+ self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
+ last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
+ last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳
+ self.stat_period = [item for item in self.stat_period if item[0] != "all_time"] # 删除"所有时间"的统计时段
+ self.stat_period.append(("all_time", now - last_stat_timestamp, "自部署以来的"))
- final_stats = {}
- for period_key, _ in stat_start_timestamp_config:
- final_stats[period_key] = {}
- final_stats[period_key].update(model_req_stat.get(period_key, {}))
- final_stats[period_key].update(online_time_stat.get(period_key, {}))
- final_stats[period_key].update(message_count_stat.get(period_key, {}))
+ stat_start_timestamp = [(period[0], now - period[1]) for period in self.stat_period]
- for stat_field_key in [
- TOTAL_REQ_CNT,
- REQ_CNT_BY_TYPE,
- REQ_CNT_BY_USER,
- REQ_CNT_BY_MODEL,
- IN_TOK_BY_TYPE,
- IN_TOK_BY_USER,
- IN_TOK_BY_MODEL,
- OUT_TOK_BY_TYPE,
- OUT_TOK_BY_USER,
- OUT_TOK_BY_MODEL,
- TOTAL_TOK_BY_TYPE,
- TOTAL_TOK_BY_USER,
- TOTAL_TOK_BY_MODEL,
- TOTAL_COST,
- COST_BY_TYPE,
- COST_BY_USER,
- COST_BY_MODEL,
- ONLINE_TIME,
- TOTAL_MSG_CNT,
- MSG_CNT_BY_CHAT,
- ]:
- if stat_field_key not in final_stats[period_key]:
- # Initialize with appropriate default type if key is missing
- if "BY_" in stat_field_key: # These are usually defaultdicts
- final_stats[period_key][stat_field_key] = defaultdict(
- int if "CNT" in stat_field_key or "TOK" in stat_field_key else float
- )
- elif "CNT" in stat_field_key or "TOK" in stat_field_key:
- final_stats[period_key][stat_field_key] = 0
- elif "COST" in stat_field_key or ONLINE_TIME == stat_field_key:
- final_stats[period_key][stat_field_key] = 0.0
- return final_stats
+ stat = {item[0]: {} for item in self.stat_period}
+
+ model_req_stat = self._collect_model_request_for_period(stat_start_timestamp)
+ online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now)
+ message_count_stat = self._collect_message_count_for_period(stat_start_timestamp)
+
+ # 统计数据合并
+ # 合并三类统计数据
+ for period_key, _ in stat_start_timestamp:
+ stat[period_key].update(model_req_stat[period_key])
+ stat[period_key].update(online_time_stat[period_key])
+ stat[period_key].update(message_count_stat[period_key])
+
+ if last_all_time_stat:
+ # 若存在上次完整统计数据,则将其与当前统计数据合并
+ for key, val in last_all_time_stat.items():
+ if isinstance(val, dict):
+ # 是字典类型,则进行合并
+ for sub_key, sub_val in val.items():
+ stat["all_time"][key][sub_key] += sub_val
+ else:
+ # 直接合并
+ stat["all_time"][key] += val
+
+ # 更新上次完整统计数据的时间戳
+ local_storage["last_full_statistics"] = {
+ "name_mapping": self.name_mapping,
+ "stat_data": stat["all_time"],
+ "timestamp": now.timestamp(),
+ }
+
+ return stat
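
The incremental merge above adds the previous all_time snapshot onto the freshly collected counters: nested dicts are merged key by key, scalars are summed. A reduced sketch with made-up numbers (defaultdict matches how the period stats are initialized):

from collections import defaultdict

current = {"total": 3, "by_model": defaultdict(int, {"gpt": 3})}   # fresh counters
snapshot = {"total": 10, "by_model": {"gpt": 7, "claude": 2}}      # last full snapshot

for key, val in snapshot.items():
    if isinstance(val, dict):
        for sub_key, sub_val in val.items():
            current[key][sub_key] += sub_val   # defaultdict absorbs unseen keys
    else:
        current[key] += val

assert current["total"] == 13 and current["by_model"] == {"gpt": 10, "claude": 2}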
# -- 以下为统计数据格式化方法 --
@@ -485,13 +439,15 @@ class StatisticOutputTask(AsyncTask):
"""
格式化总统计数据
"""
+
output = [
- f"总在线时间: {_format_online_time(stats.get(ONLINE_TIME, 0))}",
- f"总消息数: {stats.get(TOTAL_MSG_CNT, 0)}",
- f"总请求数: {stats.get(TOTAL_REQ_CNT, 0)}",
- f"总花费: {stats.get(TOTAL_COST, 0.0):.4f}¥",
+ f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}",
+ f"总消息数: {stats[TOTAL_MSG_CNT]}",
+ f"总请求数: {stats[TOTAL_REQ_CNT]}",
+ f"总花费: {stats[TOTAL_COST]:.4f}¥",
"",
]
+
return "\n".join(output)
@staticmethod
@@ -499,183 +455,174 @@ class StatisticOutputTask(AsyncTask):
"""
格式化按模型分类的统计数据
"""
- if stats.get(TOTAL_REQ_CNT, 0) > 0:
- data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
- output = [
- "按模型分类统计:",
- " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
- ]
- req_cnt_by_model = stats.get(REQ_CNT_BY_MODEL, {})
- in_tok_by_model = stats.get(IN_TOK_BY_MODEL, defaultdict(int))
- out_tok_by_model = stats.get(OUT_TOK_BY_MODEL, defaultdict(int))
- total_tok_by_model = stats.get(TOTAL_TOK_BY_MODEL, defaultdict(int))
- cost_by_model = stats.get(COST_BY_MODEL, defaultdict(float))
-
- for model_name, count in sorted(req_cnt_by_model.items()):
- name = model_name[:29] + "..." if len(model_name) > 32 else model_name
- in_tokens = in_tok_by_model[model_name]
- out_tokens = out_tok_by_model[model_name]
- tokens = total_tok_by_model[model_name]
- cost = cost_by_model[model_name]
- output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
-
- output.append("")
- return "\n".join(output)
- else:
+ if stats[TOTAL_REQ_CNT] <= 0:
return ""
+ data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.4f}¥"
+
+ output = [
+ "按模型分类统计:",
+ " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费",
+ ]
+ for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
+ name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
+ in_tokens = stats[IN_TOK_BY_MODEL][model_name]
+ out_tokens = stats[OUT_TOK_BY_MODEL][model_name]
+ tokens = stats[TOTAL_TOK_BY_MODEL][model_name]
+ cost = stats[COST_BY_MODEL][model_name]
+ output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost))
+
+ output.append("")
+ return "\n".join(output)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
"""
格式化聊天统计数据
"""
- if stats.get(TOTAL_MSG_CNT, 0) > 0:
- output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
- msg_cnt_by_chat = stats.get(MSG_CNT_BY_CHAT, {})
- for chat_id, count in sorted(msg_cnt_by_chat.items()):
- chat_name_display = self.name_mapping.get(chat_id, (f"未知 ({chat_id})", None))[0]
- output.append(f"{chat_name_display[:32]:<32} {count:>10}")
-
- output.append("")
- return "\n".join(output)
- else:
+ if stats[TOTAL_MSG_CNT] <= 0:
return ""
+ output = ["聊天消息统计:", " 联系人/群组名称 消息数量"]
+ output.extend(
+ f"{self.name_mapping[chat_id][0][:32]:<32} {count:>10}"
+ for chat_id, count in sorted(stats[MSG_CNT_BY_CHAT].items())
+ )
+ output.append("")
+ return "\n".join(output)
- def _generate_html_report(self, stat_collection: dict[str, Any], now: datetime):
+ def _generate_html_report(self, stat: dict[str, Any], now: datetime):
"""
生成HTML格式的统计报告
- :param stat_collection: 包含所有时间段统计数据的字典 {period_key: stats_dict}
+ :param stat: 统计数据
:param now: 基准当前时间
+ :return: HTML格式的统计报告
"""
- # Correctly get deploy_time_dt for display purposes
- if "deploy_time" in local_storage:
- try:
- deploy_time_dt = datetime.fromtimestamp(local_storage["deploy_time"])
- except (TypeError, ValueError):
- logger.error("Invalid deploy_time in local_storage for HTML report. Using default.")
- deploy_time_dt = datetime(2000, 1, 1) # Fallback
- else:
- # This should ideally not happen if __init__ or _collect_all_statistics ran
- logger.warning("deploy_time not found in local_storage for HTML report. Using default.")
- deploy_time_dt = datetime(2000, 1, 1) # Fallback
- tab_list_html = []
- tab_content_html_list = []
+        tab_list = [
+            f'<button class="tab-link" onclick="showTab(event, \'{period[0]}\')">{period[2]}</button>'
+            for period in self.stat_period
+        ]
- for (
- period_key,
- period_delta,
- period_display_name,
- ) in self.stat_period: # Use self.stat_period as defined by _collect_all_statistics
-            tab_list_html.append(
-                f'<button class="tab-link" onclick="showTab(event, \'{period_key}\')">{period_display_name}</button>'
+ def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str:
+ """
+ 格式化一个时间段的统计数据到html div块
+ :param stat_data: 统计数据
+ :param div_id: div的ID
+ :param start_time: 统计时间段开始时间
+ """
+ # format总在线时间
+
+ # 按模型分类统计
+            model_rows = "\n".join(
+                [
+                    f"<tr>"
+                    f"<td>{model_name}</td>"
+                    f"<td>{count}</td>"
+                    f"<td>{stat_data[IN_TOK_BY_MODEL][model_name]}</td>"
+                    f"<td>{stat_data[OUT_TOK_BY_MODEL][model_name]}</td>"
+                    f"<td>{stat_data[TOTAL_TOK_BY_MODEL][model_name]}</td>"
+                    f"<td>{stat_data[COST_BY_MODEL][model_name]:.4f} ¥</td>"
+                    f"</tr>"
+                    for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
+                ]
+            )
-
- current_period_stats = stat_collection.get(period_key, {})
-
- if period_key == "all_time":
- start_time_dt_for_period = deploy_time_dt
- else:
- # Ensure period_delta is a timedelta object
- if isinstance(period_delta, timedelta):
- start_time_dt_for_period = now - period_delta
- else: # Fallback if period_delta is not as expected (e.g. from old self.stat_period)
- logger.warning(
- f"period_delta for {period_key} is not a timedelta. Using 'now'. Type: {type(period_delta)}"
- )
- start_time_dt_for_period = now
-
- html_content_for_tab = f"""
-
-
+ # 按请求类型分类统计
+            type_rows = "\n".join(
+                [
+                    f"<tr>"
+                    f"<td>{req_type}</td>"
+                    f"<td>{count}</td>"
+                    f"<td>{stat_data[IN_TOK_BY_TYPE][req_type]}</td>"
+                    f"<td>{stat_data[OUT_TOK_BY_TYPE][req_type]}</td>"
+                    f"<td>{stat_data[TOTAL_TOK_BY_TYPE][req_type]}</td>"
+                    f"<td>{stat_data[COST_BY_TYPE][req_type]:.4f} ¥</td>"
+                    f"</tr>"
+                    for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
+                ]
+            )
+ # 按用户分类统计
+            user_rows = "\n".join(
+                [
+                    f"<tr>"
+                    f"<td>{user_id}</td>"
+                    f"<td>{count}</td>"
+                    f"<td>{stat_data[IN_TOK_BY_USER][user_id]}</td>"
+                    f"<td>{stat_data[OUT_TOK_BY_USER][user_id]}</td>"
+                    f"<td>{stat_data[TOTAL_TOK_BY_USER][user_id]}</td>"
+                    f"<td>{stat_data[COST_BY_USER][user_id]:.4f} ¥</td>"
+                    f"</tr>"
+                    for user_id, count in sorted(stat_data[REQ_CNT_BY_USER].items())
+                ]
+            )
+ # 聊天消息统计
+            chat_rows = "\n".join(
+                [
+                    f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
+                    for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
+                ]
+            )
+ # 生成HTML
+ return f"""
+
+
统计时段:
- {start_time_dt_for_period.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}
+ {start_time.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}
-
总在线时间: {_format_online_time(current_period_stats.get(ONLINE_TIME, 0))}
-
总消息数: {current_period_stats.get(TOTAL_MSG_CNT, 0)}
-
总请求数: {current_period_stats.get(TOTAL_REQ_CNT, 0)}
-
总花费: {current_period_stats.get(TOTAL_COST, 0.0):.4f} ¥
+
总在线时间: {_format_online_time(stat_data[ONLINE_TIME])}
+
总消息数: {stat_data[TOTAL_MSG_CNT]}
+
总请求数: {stat_data[TOTAL_REQ_CNT]}
+
总花费: {stat_data[TOTAL_COST]:.4f} ¥
+
+
按模型分类统计
+
+ 模型名称 调用次数 输入Token 输出Token Token总量 累计花费
+
+ {model_rows}
+
+
+
+
按请求类型分类统计
+
+
+ 请求类型 调用次数 输入Token 输出Token Token总量 累计花费
+
+
+ {type_rows}
+
+
+
+
按用户分类统计
+
+
+ 用户名称 调用次数 输入Token 输出Token Token总量 累计花费
+
+
+ {user_rows}
+
+
+
+
聊天消息统计
+
+
+ 联系人/群组名称 消息数量
+
+
+ {chat_rows}
+
+
+
"""
- html_content_for_tab += "
按模型分类统计 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 "
- req_cnt_by_model = current_period_stats.get(REQ_CNT_BY_MODEL, {})
- in_tok_by_model = current_period_stats.get(IN_TOK_BY_MODEL, defaultdict(int))
- out_tok_by_model = current_period_stats.get(OUT_TOK_BY_MODEL, defaultdict(int))
- total_tok_by_model = current_period_stats.get(TOTAL_TOK_BY_MODEL, defaultdict(int))
- cost_by_model = current_period_stats.get(COST_BY_MODEL, defaultdict(float))
- if req_cnt_by_model:
- for model_name, count in sorted(req_cnt_by_model.items()):
- html_content_for_tab += (
- f""
- f"{model_name} "
- f"{count} "
- f"{in_tok_by_model[model_name]} "
- f"{out_tok_by_model[model_name]} "
- f"{total_tok_by_model[model_name]} "
- f"{cost_by_model[model_name]:.4f} ¥ "
- f" "
- )
- else:
- html_content_for_tab += "无数据 "
- html_content_for_tab += "
"
+ tab_content_list = [
+ _format_stat_data(stat[period[0]], period[0], now - period[1])
+ for period in self.stat_period
+ if period[0] != "all_time"
+ ]
- html_content_for_tab += "
按请求类型分类统计 请求类型 调用次数 输入Token 输出Token Token总量 累计花费 "
- req_cnt_by_type = current_period_stats.get(REQ_CNT_BY_TYPE, {})
- in_tok_by_type = current_period_stats.get(IN_TOK_BY_TYPE, defaultdict(int))
- out_tok_by_type = current_period_stats.get(OUT_TOK_BY_TYPE, defaultdict(int))
- total_tok_by_type = current_period_stats.get(TOTAL_TOK_BY_TYPE, defaultdict(int))
- cost_by_type = current_period_stats.get(COST_BY_TYPE, defaultdict(float))
- if req_cnt_by_type:
- for req_type, count in sorted(req_cnt_by_type.items()):
- html_content_for_tab += (
- f""
- f"{req_type} "
- f"{count} "
- f"{in_tok_by_type[req_type]} "
- f"{out_tok_by_type[req_type]} "
- f"{total_tok_by_type[req_type]} "
- f"{cost_by_type[req_type]:.4f} ¥ "
- f" "
- )
- else:
- html_content_for_tab += "无数据 "
- html_content_for_tab += "
"
+ tab_content_list.append(
+ _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
+ )
- html_content_for_tab += "
按用户分类统计 用户ID/名称 调用次数 输入Token 输出Token Token总量 累计花费 "
- req_cnt_by_user = current_period_stats.get(REQ_CNT_BY_USER, {})
- in_tok_by_user = current_period_stats.get(IN_TOK_BY_USER, defaultdict(int))
- out_tok_by_user = current_period_stats.get(OUT_TOK_BY_USER, defaultdict(int))
- total_tok_by_user = current_period_stats.get(TOTAL_TOK_BY_USER, defaultdict(int))
- cost_by_user = current_period_stats.get(COST_BY_USER, defaultdict(float))
- if req_cnt_by_user:
- for user_id, count in sorted(req_cnt_by_user.items()):
- user_display_name = self.name_mapping.get(user_id, (user_id, None))[0]
- html_content_for_tab += (
- f""
- f"{user_display_name} "
- f"{count} "
- f"{in_tok_by_user[user_id]} "
- f"{out_tok_by_user[user_id]} "
- f"{total_tok_by_user[user_id]} "
- f"{cost_by_user[user_id]:.4f} ¥ "
- f" "
- )
- else:
- html_content_for_tab += "无数据 "
- html_content_for_tab += "
"
-
- html_content_for_tab += (
- "
聊天消息统计 联系人/群组名称 消息数量 "
- )
- msg_cnt_by_chat = current_period_stats.get(MSG_CNT_BY_CHAT, {})
- if msg_cnt_by_chat:
- for chat_id, count in sorted(msg_cnt_by_chat.items()):
- chat_name_display = self.name_mapping.get(chat_id, (f"未知/归档聊天 ({chat_id})", None))[0]
- html_content_for_tab += f"{chat_name_display} {count} "
- else:
- html_content_for_tab += "无数据 "
- html_content_for_tab += "
"
-
- tab_content_html_list.append(html_content_for_tab)
+ joined_tab_list = "\n".join(tab_list)
+ joined_tab_content = "\n".join(tab_content_list)
html_template = (
"""
@@ -739,7 +686,6 @@ class StatisticOutputTask(AsyncTask):
border: 1px solid #ddd;
padding: 10px;
text-align: left;
- word-break: break-all;
}
th {
background-color: #3498db;
@@ -758,38 +704,24 @@ class StatisticOutputTask(AsyncTask):
.tabs {
overflow: hidden;
background: #ecf0f1;
- display: flex;
- flex-wrap: wrap;
- margin-bottom: -1px;
+ display: flex;
}
.tabs button {
- background: inherit;
- border: 1px solid #ccc;
- border-bottom: none;
- outline: none;
- padding: 14px 16px;
- cursor: pointer;
- transition: 0.3s;
- font-size: 16px;
- margin-right: 2px;
- border-radius: 4px 4px 0 0;
+ background: inherit; border: none; outline: none;
+ padding: 14px 16px; cursor: pointer;
+ transition: 0.3s; font-size: 16px;
}
.tabs button:hover {
background-color: #d4dbdc;
}
.tabs button.active {
- background-color: #fff;
- border-color: #ccc;
- border-bottom: 1px solid #fff;
- position: relative;
- z-index: 1;
+ background-color: #b3bbbd;
}
.tab-content {
display: none;
padding: 20px;
background-color: #fff;
border: 1px solid #ccc;
- border-top: none;
}
.tab-content.active {
display: block;
@@ -804,14 +736,10 @@ class StatisticOutputTask(AsyncTask):
统计截止时间: {now.strftime("%Y-%m-%d %H:%M:%S")}
- {"".join(tab_list_html)}
+ {joined_tab_list}
- {"".join(tab_content_html_list)}
-
-
+ {joined_tab_content}
"""
+ """
@@ -820,35 +748,20 @@ class StatisticOutputTask(AsyncTask):
tab_content = document.getElementsByClassName("tab-content");
tab_links = document.getElementsByClassName("tab-link");
- if (tab_content.length > 0 && tab_links.length > 0) {
- tab_content[0].classList.add("active");
- tab_links[0].classList.add("active");
- }
+ tab_content[0].classList.add("active");
+ tab_links[0].classList.add("active");
- function showTab(evt, tabName) {
- for (i = 0; i < tab_content.length; i++) {
- tab_content[i].classList.remove("active");
- }
- for (i = 0; i < tab_links.length; i++) {
- tab_links[i].classList.remove("active");
- }
- const currentTabContent = document.getElementById(tabName);
- if (currentTabContent) {
- currentTabContent.classList.add("active");
- }
- if (evt.currentTarget) {
- evt.currentTarget.classList.add("active");
- }
- }
+ function showTab(evt, tabName) {{
+ for (i = 0; i < tab_content.length; i++) tab_content[i].classList.remove("active");
+ for (i = 0; i < tab_links.length; i++) tab_links[i].classList.remove("active");
+ document.getElementById(tabName).classList.add("active");
+ evt.currentTarget.classList.add("active");
+ }}
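
The doubled braces in the new showTab exist because the surrounding template goes through Python string formatting, where a literal "{" must be written "{{". A one-line demonstration:

template = "function showTab(evt, tabName) {{ console.log(tabName); }} tab={tab}"
print(template.format(tab="all_time"))
# -> function showTab(evt, tabName) { console.log(tabName); } tab=all_time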