pull/1258/head
SengokuCola 2025-09-24 10:19:26 +08:00
commit a1e9893ac1
4 changed files with 79 additions and 60 deletions

View File

@ -28,7 +28,7 @@ version = "1.1.1"
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
@ -43,19 +43,19 @@ retry_interval = 10 # Retry interval (seconds)
| `name` | ✅ | Provider name, referenced by model configurations | - |
| `base_url` | ✅ | Base URL of the API service | - |
| `api_key` | ✅ | API key; replace it with your actual key | - |
| `client_type` | ❌ | Client type: `openai` (OpenAI format) or `gemini` (Gemini format; support is currently poor) | `openai` |
| `client_type` | ❌ | Client type: `openai` (OpenAI format) or `gemini` (Gemini format) | `openai` |
| `max_retry` | ❌ | Maximum number of retries when an API call fails | 2 |
| `timeout` | ❌ | API request timeout (seconds) | 30 |
| `retry_interval` | ❌ | Retry interval (seconds) | 10 |

**Please note: for models whose `client_type` is `gemini`, the `base_url` field has no effect.**
**Please note: for models whose `client_type` is `gemini`, the `retry` field is determined by `gemini` itself.**
### 2.3 Supported Provider Examples
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.cn/v1"
base_url = "https://api.deepseek.com/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
@ -73,7 +73,7 @@ client_type = "openai"
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
base_url = "https://generativelanguage.googleapis.com/v1beta"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
@ -131,9 +131,20 @@ enable_thinking = false # Disable thinking
[models.extra_params]
thinking = {type = "disabled"} # Disable thinking
```
For `gemini`, however, this needs to be configured separately:
```toml
[[models]]
model_identifier = "gemini-2.5-flash"
name = "gemini-2.5-flash"
api_provider = "Google"
[models.extra_params]
thinking_budget = 0 # Disable thinking
# thinking_budget = -1 lets the model decide on its own
```
Please note that the `extra_params` configuration must form a valid TOML table; its exact contents depend on the requirements of the API provider.
**Please note: for models whose `client_type` is `gemini`, this field has no effect.**
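For illustration, here is a minimal sketch of two equivalent ways to express `extra_params` as a valid TOML table, reusing the `enable_thinking` key from the example above (substitute whatever keys your provider actually accepts):
```toml
# Sub-table form, as used in the examples above
[models.extra_params]
enable_thinking = false

# Equivalent inline-table form; note it must be written inside the
# [[models]] entry itself, before any sub-table header:
# extra_params = { enable_thinking = false }
```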
### 3.3 Configuration Parameters
| Parameter | Required | Description |

View File

@ -3,7 +3,7 @@ import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo, Seg
from maim_message import UserInfo, Seg, GroupInfo
from src.common.logger import get_logger
from src.config.config import global_config
@ -27,7 +27,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否包含过滤词
Args:
@ -40,14 +40,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "private chat"
chat_name = group_info.group_name if group_info else "private chat"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[banned-word filter] message contains {word}, filtered")
return True
return False
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
"""检查消息是否匹配过滤正则表达式
Args:
@ -61,10 +61,10 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
# Check whether text is None or an empty string
if text is None or not text:
return False
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
chat_name = chat.group_info.group_name if chat.group_info else "private chat"
chat_name = group_info.group_name if group_info else "private chat"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[regex filter] message matched {pattern}, filtered")
return True
@ -251,6 +251,18 @@ class ChatBot:
# return
pass
# Filter checks
if _check_ban_words(
message.processed_plain_text,
user_info, # type: ignore
group_info,
) or _check_ban_regex(
message.raw_message, # type: ignore
user_info, # type: ignore
group_info,
):
return
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
@ -268,14 +280,6 @@ class ChatBot:
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
# return
# Filter checks
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return
# Command handling - check and process commands via the new plugin system
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)

View File

@ -1,7 +1,7 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
from google import genai
from google.genai.types import (
@ -17,6 +17,7 @@ from google.genai.types import (
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HttpOptions,
HarmCategory,
HarmBlockThreshold,
)
@ -345,22 +346,27 @@ class GeminiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# Handle incoming provider parameters
http_options_kwargs: Dict[str, Any] = {}
# Convert seconds to milliseconds before passing in
if api_provider.timeout is not None:
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
# Parse the base URL and API version (must be in Gemini format)
if api_provider.base_url:
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
if len(parts) == 2 and parts[1].startswith("v"):
http_options_kwargs["base_url"] = f"{parts[0]}/"
http_options_kwargs["api_version"] = parts[1]
else:
http_options_kwargs["base_url"] = api_provider.base_url
self.client = genai.Client(
http_options=HttpOptions(**http_options_kwargs),
api_key=api_provider.api_key,
) # Unlike openai, gemini decides on its own whether it needs to retry
# Experimental: try passing in a custom base_url (must be in Gemini format)
if hasattr(api_provider, "base_url") and api_provider.base_url:
base_url = api_provider.base_url.rstrip("/") # strip the trailing /
self.client._api_client._http_options.base_url = base_url
# If base_url already includes /v1 or /v1beta, clear the SDK's api_version
if base_url.endswith("/v1") or base_url.endswith("/v1beta"):
self.client._api_client._http_options.api_version = None
# Let GeminiClient itself access the underlying api_client as well
self._api_client = self.client._api_client
@staticmethod
def clamp_thinking_budget(tb: int, model_id: str) -> int:
"""
@ -380,20 +386,29 @@ class GeminiClient(BaseClient):
limits = THINKING_BUDGET_LIMITS[key]
break
# Handle special values
# Handle budget values
if tb == THINKING_BUDGET_AUTO:
return THINKING_BUDGET_AUTO
if tb == THINKING_BUDGET_DISABLED:
if limits and limits.get("can_disable", False):
return THINKING_BUDGET_DISABLED
return limits["min"] if limits else THINKING_BUDGET_AUTO
if limits:
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
return limits["min"]
return THINKING_BUDGET_AUTO
# Known models: clamp to range
# Known models: clamp to range, with a warning
if limits:
return max(limits["min"], min(tb, limits["max"]))
if tb < limits["min"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
return limits["min"]
if tb > limits["max"]:
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
return limits["max"]
return tb
# Unknown model: return dynamic mode
logger.warning(f"Model {model_id} is not defined in THINKING_BUDGET_LIMITS; using dynamic mode tb=-1 for compatibility.")
# Unknown model → default to automatic mode
logger.warning(f"Model {model_id} is not defined in THINKING_BUDGET_LIMITS; enabling the model's automatic budget for compatibility")
return THINKING_BUDGET_AUTO
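# A hypothetical walk-through (limits invented for illustration): if a model's
# THINKING_BUDGET_LIMITS entry were {"min": 128, "max": 32768, "can_disable": False},
# then clamp_thinking_budget(0, ...) would fall back to 128 with a warning,
# 50000 would be clamped to 32768, and 1024 would be returned unchanged.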
async def get_response(
@ -448,7 +463,9 @@ class GeminiClient(BaseClient):
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
)
# Clamp to the range supported by the model
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)

View File

@ -26,18 +26,6 @@ install(extra_lines=3)
logger = get_logger("model_utils")
# Common error code mapping
error_code_mapping = {
400: "Invalid request parameters",
401: "API key error; authentication failed. Please check that config/model_config.toml is configured correctly",
402: "Insufficient account balance",
403: "Real-name verification required, or insufficient balance",
404: "Not Found",
429: "Too many requests; please try again later",
500: "Internal server error",
503: "Server under heavy load",
}
class RequestType(Enum):
"""请求类型枚举"""
@ -267,14 +255,14 @@ class LLMRequest:
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input is not None
assert embedding_input is not None, "embedding input must not be empty"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
assert audio_base64 is not None
assert audio_base64 is not None, "audio base64 must not be empty"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
@ -365,24 +353,23 @@ class LLMRequest:
embedding_input=embedding_input,
audio_base64=audio_base64,
)
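# model_usage maps a model name to a (total_tokens, penalty, usage_penalty) tuple;
# both the success path here and the failure handler below update this bookkeeping.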
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
if response_usage := response.usage:
total_tokens += response_usage.total_tokens
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
return response, model_info
except ModelAttemptFailed as e:
last_exception = e.original_exception or e
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
failed_models_this_request.add(model_info.name)
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
raise last_exception from e
finally:
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
if usage_penalty > 0:
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
if last_exception:
raise last_exception