mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
commit
a1e9893ac1
|
|
@ -28,7 +28,7 @@ version = "1.1.1"
|
||||||
```toml
|
```toml
|
||||||
[[api_providers]]
|
[[api_providers]]
|
||||||
name = "DeepSeek" # 服务商名称(自定义)
|
name = "DeepSeek" # 服务商名称(自定义)
|
||||||
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
|
base_url = "https://api.deepseek.com/v1" # API服务的基础URL
|
||||||
api_key = "your-api-key-here" # API密钥
|
api_key = "your-api-key-here" # API密钥
|
||||||
client_type = "openai" # 客户端类型
|
client_type = "openai" # 客户端类型
|
||||||
max_retry = 2 # 最大重试次数
|
max_retry = 2 # 最大重试次数
|
||||||
|
|
@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒)
|
||||||
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
|
||||||
| `base_url` | ✅ | API服务的基础URL | - |
|
| `base_url` | ✅ | API服务的基础URL | - |
|
||||||
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
|
||||||
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` |
|
| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` |
|
||||||
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
|
||||||
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
| `timeout` | ❌ | API请求超时时间(秒) | 30 |
|
||||||
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
|
||||||
|
|
||||||
**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。**
|
**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。**
|
||||||
### 2.3 支持的服务商示例
|
### 2.3 支持的服务商示例
|
||||||
|
|
||||||
#### DeepSeek
|
#### DeepSeek
|
||||||
```toml
|
```toml
|
||||||
[[api_providers]]
|
[[api_providers]]
|
||||||
name = "DeepSeek"
|
name = "DeepSeek"
|
||||||
base_url = "https://api.deepseek.cn/v1"
|
base_url = "https://api.deepseek.com/v1"
|
||||||
api_key = "your-deepseek-api-key"
|
api_key = "your-deepseek-api-key"
|
||||||
client_type = "openai"
|
client_type = "openai"
|
||||||
```
|
```
|
||||||
|
|
@ -73,7 +73,7 @@ client_type = "openai"
|
||||||
```toml
|
```toml
|
||||||
[[api_providers]]
|
[[api_providers]]
|
||||||
name = "Google"
|
name = "Google"
|
||||||
base_url = "https://api.google.com/v1"
|
base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
api_key = "your-google-api-key"
|
api_key = "your-google-api-key"
|
||||||
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
client_type = "gemini" # 注意:Gemini需要使用特殊客户端
|
||||||
```
|
```
|
||||||
|
|
@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考
|
||||||
[models.extra_params]
|
[models.extra_params]
|
||||||
thinking = {type = "disabled"} # 禁用思考
|
thinking = {type = "disabled"} # 禁用思考
|
||||||
```
|
```
|
||||||
|
|
||||||
|
而对于`gemini`需要单独进行配置
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
model_identifier = "gemini-2.5-flash"
|
||||||
|
name = "gemini-2.5-flash"
|
||||||
|
api_provider = "Google"
|
||||||
|
[models.extra_params]
|
||||||
|
thinking_budget = 0 # 禁用思考
|
||||||
|
# thinking_budget = -1 由模型自己决定
|
||||||
|
```
|
||||||
|
|
||||||
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
|
||||||
|
|
||||||
**请注意,对于`client_type`为`gemini`的模型,此字段无效。**
|
|
||||||
### 3.3 配置参数说明
|
### 3.3 配置参数说明
|
||||||
|
|
||||||
| 参数 | 必填 | 说明 |
|
| 参数 | 必填 | 说明 |
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from maim_message import UserInfo, Seg
|
from maim_message import UserInfo, Seg, GroupInfo
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -27,7 +27,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
|
||||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||||
"""检查消息是否包含过滤词
|
"""检查消息是否包含过滤词
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -40,14 +40,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||||
"""
|
"""
|
||||||
for word in global_config.message_receive.ban_words:
|
for word in global_config.message_receive.ban_words:
|
||||||
if word in text:
|
if word in text:
|
||||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
chat_name = group_info.group_name if group_info else "私聊"
|
||||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||||
"""检查消息是否匹配过滤正则表达式
|
"""检查消息是否匹配过滤正则表达式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -61,10 +61,10 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||||
# 检查text是否为None或空字符串
|
# 检查text是否为None或空字符串
|
||||||
if text is None or not text:
|
if text is None or not text:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||||
if re.search(pattern, text):
|
if re.search(pattern, text):
|
||||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
chat_name = group_info.group_name if group_info else "私聊"
|
||||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||||
return True
|
return True
|
||||||
|
|
@ -251,6 +251,18 @@ class ChatBot:
|
||||||
# return
|
# return
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 过滤检查
|
||||||
|
if _check_ban_words(
|
||||||
|
message.processed_plain_text,
|
||||||
|
user_info, # type: ignore
|
||||||
|
group_info,
|
||||||
|
) or _check_ban_regex(
|
||||||
|
message.raw_message, # type: ignore
|
||||||
|
user_info, # type: ignore
|
||||||
|
group_info,
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
|
|
@ -268,14 +280,6 @@ class ChatBot:
|
||||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||||
# return
|
# return
|
||||||
|
|
||||||
# 过滤检查
|
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
|
||||||
message.raw_message, # type: ignore
|
|
||||||
chat,
|
|
||||||
user_info, # type: ignore
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
# 命令处理 - 使用新插件系统检查并处理命令
|
# 命令处理 - 使用新插件系统检查并处理命令
|
||||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai.types import (
|
from google.genai.types import (
|
||||||
|
|
@ -17,6 +17,7 @@ from google.genai.types import (
|
||||||
EmbedContentResponse,
|
EmbedContentResponse,
|
||||||
EmbedContentConfig,
|
EmbedContentConfig,
|
||||||
SafetySetting,
|
SafetySetting,
|
||||||
|
HttpOptions,
|
||||||
HarmCategory,
|
HarmCategory,
|
||||||
HarmBlockThreshold,
|
HarmBlockThreshold,
|
||||||
)
|
)
|
||||||
|
|
@ -345,22 +346,27 @@ class GeminiClient(BaseClient):
|
||||||
|
|
||||||
def __init__(self, api_provider: APIProvider):
|
def __init__(self, api_provider: APIProvider):
|
||||||
super().__init__(api_provider)
|
super().__init__(api_provider)
|
||||||
|
|
||||||
|
# 增加传入参数处理
|
||||||
|
http_options_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# 秒转换为毫秒传入
|
||||||
|
if api_provider.timeout is not None:
|
||||||
|
http_options_kwargs["timeout"] = int(api_provider.timeout * 1000)
|
||||||
|
|
||||||
|
# 传入并处理地址和版本(必须为Gemini格式)
|
||||||
|
if api_provider.base_url:
|
||||||
|
parts = api_provider.base_url.rstrip("/").rsplit("/", 1)
|
||||||
|
if len(parts) == 2 and parts[1].startswith("v"):
|
||||||
|
http_options_kwargs["base_url"] = f"{parts[0]}/"
|
||||||
|
http_options_kwargs["api_version"] = parts[1]
|
||||||
|
else:
|
||||||
|
http_options_kwargs["base_url"] = api_provider.base_url
|
||||||
self.client = genai.Client(
|
self.client = genai.Client(
|
||||||
|
http_options=HttpOptions(**http_options_kwargs),
|
||||||
api_key=api_provider.api_key,
|
api_key=api_provider.api_key,
|
||||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
||||||
|
|
||||||
# 尝试传入自定义base_url(实验性,必须为Gemini格式)
|
|
||||||
if hasattr(api_provider, "base_url") and api_provider.base_url:
|
|
||||||
base_url = api_provider.base_url.rstrip("/") # 去掉末尾 /
|
|
||||||
self.client._api_client._http_options.base_url = base_url
|
|
||||||
|
|
||||||
# 如果 base_url 已经带了 /v1 或 /v1beta,就清掉 SDK 的 api_version
|
|
||||||
if base_url.endswith("/v1") or base_url.endswith("/v1beta"):
|
|
||||||
self.client._api_client._http_options.api_version = None
|
|
||||||
|
|
||||||
# 让 GeminiClient 内部也能访问底层 api_client
|
|
||||||
self._api_client = self.client._api_client
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clamp_thinking_budget(tb: int, model_id: str) -> int:
|
def clamp_thinking_budget(tb: int, model_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -380,20 +386,29 @@ class GeminiClient(BaseClient):
|
||||||
limits = THINKING_BUDGET_LIMITS[key]
|
limits = THINKING_BUDGET_LIMITS[key]
|
||||||
break
|
break
|
||||||
|
|
||||||
# 特殊值处理
|
# 预算值处理
|
||||||
if tb == THINKING_BUDGET_AUTO:
|
if tb == THINKING_BUDGET_AUTO:
|
||||||
return THINKING_BUDGET_AUTO
|
return THINKING_BUDGET_AUTO
|
||||||
if tb == THINKING_BUDGET_DISABLED:
|
if tb == THINKING_BUDGET_DISABLED:
|
||||||
if limits and limits.get("can_disable", False):
|
if limits and limits.get("can_disable", False):
|
||||||
return THINKING_BUDGET_DISABLED
|
return THINKING_BUDGET_DISABLED
|
||||||
return limits["min"] if limits else THINKING_BUDGET_AUTO
|
if limits:
|
||||||
|
logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}")
|
||||||
|
return limits["min"]
|
||||||
|
return THINKING_BUDGET_AUTO
|
||||||
|
|
||||||
# 已知模型裁剪到范围
|
# 已知模型范围裁剪 + 提示
|
||||||
if limits:
|
if limits:
|
||||||
return max(limits["min"], min(tb, limits["max"]))
|
if tb < limits["min"]:
|
||||||
|
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}")
|
||||||
|
return limits["min"]
|
||||||
|
if tb > limits["max"]:
|
||||||
|
logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}")
|
||||||
|
return limits["max"]
|
||||||
|
return tb
|
||||||
|
|
||||||
# 未知模型,返回动态模式
|
# 未知模型 → 默认自动模式
|
||||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
|
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容")
|
||||||
return THINKING_BUDGET_AUTO
|
return THINKING_BUDGET_AUTO
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
|
|
@ -448,7 +463,9 @@ class GeminiClient(BaseClient):
|
||||||
try:
|
try:
|
||||||
tb = int(extra_params["thinking_budget"])
|
tb = int(extra_params["thinking_budget"])
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
|
logger.warning(
|
||||||
|
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}"
|
||||||
|
)
|
||||||
# 裁剪到模型支持的范围
|
# 裁剪到模型支持的范围
|
||||||
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,18 +26,6 @@ install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("model_utils")
|
logger = get_logger("model_utils")
|
||||||
|
|
||||||
# 常见Error Code Mapping
|
|
||||||
error_code_mapping = {
|
|
||||||
400: "参数不正确",
|
|
||||||
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
|
|
||||||
402: "账号余额不足",
|
|
||||||
403: "需要实名,或余额不足",
|
|
||||||
404: "Not Found",
|
|
||||||
429: "请求过于频繁,请稍后再试",
|
|
||||||
500: "服务器内部故障",
|
|
||||||
503: "服务器负载过高",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class RequestType(Enum):
|
class RequestType(Enum):
|
||||||
"""请求类型枚举"""
|
"""请求类型枚举"""
|
||||||
|
|
@ -267,14 +255,14 @@ class LLMRequest:
|
||||||
extra_params=model_info.extra_params,
|
extra_params=model_info.extra_params,
|
||||||
)
|
)
|
||||||
elif request_type == RequestType.EMBEDDING:
|
elif request_type == RequestType.EMBEDDING:
|
||||||
assert embedding_input is not None
|
assert embedding_input is not None, "嵌入输入不能为空"
|
||||||
return await client.get_embedding(
|
return await client.get_embedding(
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
embedding_input=embedding_input,
|
embedding_input=embedding_input,
|
||||||
extra_params=model_info.extra_params,
|
extra_params=model_info.extra_params,
|
||||||
)
|
)
|
||||||
elif request_type == RequestType.AUDIO:
|
elif request_type == RequestType.AUDIO:
|
||||||
assert audio_base64 is not None
|
assert audio_base64 is not None, "音频Base64不能为空"
|
||||||
return await client.get_audio_transcriptions(
|
return await client.get_audio_transcriptions(
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
audio_base64=audio_base64,
|
audio_base64=audio_base64,
|
||||||
|
|
@ -365,24 +353,23 @@ class LLMRequest:
|
||||||
embedding_input=embedding_input,
|
embedding_input=embedding_input,
|
||||||
audio_base64=audio_base64,
|
audio_base64=audio_base64,
|
||||||
)
|
)
|
||||||
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
|
if response_usage := response.usage:
|
||||||
|
total_tokens += response_usage.total_tokens
|
||||||
|
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
||||||
return response, model_info
|
return response, model_info
|
||||||
|
|
||||||
except ModelAttemptFailed as e:
|
except ModelAttemptFailed as e:
|
||||||
last_exception = e.original_exception or e
|
last_exception = e.original_exception or e
|
||||||
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
|
logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}")
|
||||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
|
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1)
|
||||||
failed_models_this_request.add(model_info.name)
|
failed_models_this_request.add(model_info.name)
|
||||||
|
|
||||||
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
|
if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400:
|
||||||
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
|
logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。")
|
||||||
raise last_exception from e
|
raise last_exception from e
|
||||||
|
|
||||||
finally:
|
|
||||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
|
||||||
if usage_penalty > 0:
|
|
||||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
|
|
||||||
|
|
||||||
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
|
logger.error(f"所有 {max_attempts} 个模型均尝试失败。")
|
||||||
if last_exception:
|
if last_exception:
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue