diff --git a/docs/model_configuration_guide.md b/docs/model_configuration_guide.md index fd1cb018..f2da8be1 100644 --- a/docs/model_configuration_guide.md +++ b/docs/model_configuration_guide.md @@ -28,7 +28,7 @@ version = "1.1.1" ```toml [[api_providers]] name = "DeepSeek" # 服务商名称(自定义) -base_url = "https://api.deepseek.cn/v1" # API服务的基础URL +base_url = "https://api.deepseek.com/v1" # API服务的基础URL api_key = "your-api-key-here" # API密钥 client_type = "openai" # 客户端类型 max_retry = 2 # 最大重试次数 @@ -43,19 +43,19 @@ retry_interval = 10 # 重试间隔(秒) | `name` | ✅ | 服务商名称,需要在模型配置中引用 | - | | `base_url` | ✅ | API服务的基础URL | - | | `api_key` | ✅ | API密钥,请替换为实际密钥 | - | -| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` | +| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式) | `openai` | | `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 | | `timeout` | ❌ | API请求超时时间(秒) | 30 | | `retry_interval` | ❌ | 重试间隔时间(秒) | 10 | -**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。** +**请注意,对于`client_type`为`gemini`的模型,`retry`字段由`gemini`自己决定。** ### 2.3 支持的服务商示例 #### DeepSeek ```toml [[api_providers]] name = "DeepSeek" -base_url = "https://api.deepseek.cn/v1" +base_url = "https://api.deepseek.com/v1" api_key = "your-deepseek-api-key" client_type = "openai" ``` @@ -73,7 +73,7 @@ client_type = "openai" ```toml [[api_providers]] name = "Google" -base_url = "https://api.google.com/v1" +base_url = "https://generativelanguage.googleapis.com/v1beta" api_key = "your-google-api-key" client_type = "gemini" # 注意:Gemini需要使用特殊客户端 ``` @@ -131,9 +131,20 @@ enable_thinking = false # 禁用思考 [models.extra_params] thinking = {type = "disabled"} # 禁用思考 ``` + +而对于`gemini`需要单独进行配置 +```toml +[[models]] +model_identifier = "gemini-2.5-flash" +name = "gemini-2.5-flash" +api_provider = "Google" +[models.extra_params] +thinking_budget = 0 # 禁用思考 +# thinking_budget = -1 由模型自己决定 +``` + 请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。 -**请注意,对于`client_type`为`gemini`的模型,此字段无效。** ### 3.3 配置参数说明 | 参数 | 必填 | 说明 | diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 0709dcd8..f611bbd2 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -3,7 +3,7 @@ import os import re from typing import Dict, Any, Optional -from maim_message import UserInfo, Seg +from maim_message import UserInfo, Seg, GroupInfo from src.common.logger import get_logger from src.config.config import global_config @@ -27,7 +27,7 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.. logger = get_logger("chat") -def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: +def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool: """检查消息是否包含过滤词 Args: @@ -40,14 +40,14 @@ def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: """ for word in global_config.message_receive.ban_words: if word in text: - chat_name = chat.group_info.group_name if chat.group_info else "私聊" + chat_name = group_info.group_name if group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[过滤词识别]消息中含有{word},filtered") return True return False -def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: +def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool: """检查消息是否匹配过滤正则表达式 Args: @@ -61,10 +61,10 @@ def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: # 检查text是否为None或空字符串 if text is None or not text: return False - + for pattern in global_config.message_receive.ban_msgs_regex: if re.search(pattern, text): - chat_name = chat.group_info.group_name if chat.group_info else "私聊" + chat_name = group_info.group_name if group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") return True @@ -251,6 +251,18 @@ class ChatBot: # return pass + # 过滤检查 + if _check_ban_words( + message.processed_plain_text, + user_info, # type: ignore + group_info, + ) or _check_ban_regex( + message.raw_message, # type: ignore + user_info, # type: ignore + group_info, + ): + return + get_chat_manager().register_message(message) chat = await get_chat_manager().get_or_create_stream( @@ -268,14 +280,6 @@ class ChatBot: # logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}") # return - # 过滤检查 - if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - message.raw_message, # type: ignore - chat, - user_info, # type: ignore - ): - return - # 命令处理 - 使用新插件系统检查并处理命令 is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message) diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 67c7475e..c23ba90a 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,7 +1,7 @@ import asyncio import io import base64 -from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List +from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict from google import genai from google.genai.types import ( @@ -17,6 +17,7 @@ from google.genai.types import ( EmbedContentResponse, EmbedContentConfig, SafetySetting, + HttpOptions, HarmCategory, HarmBlockThreshold, ) @@ -345,22 +346,27 @@ class GeminiClient(BaseClient): def __init__(self, api_provider: APIProvider): super().__init__(api_provider) + + # 增加传入参数处理 + http_options_kwargs: Dict[str, Any] = {} + + # 秒转换为毫秒传入 + if api_provider.timeout is not None: + http_options_kwargs["timeout"] = int(api_provider.timeout * 1000) + + # 传入并处理地址和版本(必须为Gemini格式) + if api_provider.base_url: + parts = api_provider.base_url.rstrip("/").rsplit("/", 1) + if len(parts) == 2 and parts[1].startswith("v"): + http_options_kwargs["base_url"] = f"{parts[0]}/" + http_options_kwargs["api_version"] = parts[1] + else: + http_options_kwargs["base_url"] = api_provider.base_url self.client = genai.Client( + http_options=HttpOptions(**http_options_kwargs), api_key=api_provider.api_key, ) # 这里和openai不一样,gemini会自己决定自己是否需要retry - # 尝试传入自定义base_url(实验性,必须为Gemini格式) - if hasattr(api_provider, "base_url") and api_provider.base_url: - base_url = api_provider.base_url.rstrip("/") # 去掉末尾 / - self.client._api_client._http_options.base_url = base_url - - # 如果 base_url 已经带了 /v1 或 /v1beta,就清掉 SDK 的 api_version - if base_url.endswith("/v1") or base_url.endswith("/v1beta"): - self.client._api_client._http_options.api_version = None - - # 让 GeminiClient 内部也能访问底层 api_client - self._api_client = self.client._api_client - @staticmethod def clamp_thinking_budget(tb: int, model_id: str) -> int: """ @@ -380,20 +386,29 @@ class GeminiClient(BaseClient): limits = THINKING_BUDGET_LIMITS[key] break - # 特殊值处理 + # 预算值处理 if tb == THINKING_BUDGET_AUTO: return THINKING_BUDGET_AUTO if tb == THINKING_BUDGET_DISABLED: if limits and limits.get("can_disable", False): return THINKING_BUDGET_DISABLED - return limits["min"] if limits else THINKING_BUDGET_AUTO + if limits: + logger.warning(f"模型 {model_id} 不支持禁用思考预算,已回退到最小值 {limits['min']}") + return limits["min"] + return THINKING_BUDGET_AUTO - # 已知模型裁剪到范围 + # 已知模型范围裁剪 + 提示 if limits: - return max(limits["min"], min(tb, limits["max"])) + if tb < limits["min"]: + logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过小,已调整为最小值 {limits['min']}") + return limits["min"] + if tb > limits["max"]: + logger.warning(f"模型 {model_id} 的 thinking_budget={tb} 过大,已调整为最大值 {limits['max']}") + return limits["max"] + return tb - # 未知模型,返回动态模式 - logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。") + # 未知模型 → 默认自动模式 + logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,已启用模型自动预算兼容") return THINKING_BUDGET_AUTO async def get_response( @@ -448,7 +463,9 @@ class GeminiClient(BaseClient): try: tb = int(extra_params["thinking_budget"]) except (ValueError, TypeError): - logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}") + logger.warning( + f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用模型自动预算模式 {tb}" + ) # 裁剪到模型支持的范围 tb = self.clamp_thinking_budget(tb, model_info.model_identifier) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 8bb35ef0..b7a65842 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -26,18 +26,6 @@ install(extra_lines=3) logger = get_logger("model_utils") -# 常见Error Code Mapping -error_code_mapping = { - 400: "参数不正确", - 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", - 402: "账号余额不足", - 403: "需要实名,或余额不足", - 404: "Not Found", - 429: "请求过于频繁,请稍后再试", - 500: "服务器内部故障", - 503: "服务器负载过高", -} - class RequestType(Enum): """请求类型枚举""" @@ -267,14 +255,14 @@ class LLMRequest: extra_params=model_info.extra_params, ) elif request_type == RequestType.EMBEDDING: - assert embedding_input is not None + assert embedding_input is not None, "嵌入输入不能为空" return await client.get_embedding( model_info=model_info, embedding_input=embedding_input, extra_params=model_info.extra_params, ) elif request_type == RequestType.AUDIO: - assert audio_base64 is not None + assert audio_base64 is not None, "音频Base64不能为空" return await client.get_audio_transcriptions( model_info=model_info, audio_base64=audio_base64, @@ -365,24 +353,23 @@ class LLMRequest: embedding_input=embedding_input, audio_base64=audio_base64, ) + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + if response_usage := response.usage: + total_tokens += response_usage.total_tokens + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) return response, model_info except ModelAttemptFailed as e: last_exception = e.original_exception or e logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}") total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) + self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty - 1) failed_models_this_request.add(model_info.name) if isinstance(last_exception, RespNotOkException) and last_exception.status_code == 400: logger.error("收到不可恢复的客户端错误 (400),中止所有尝试。") raise last_exception from e - finally: - total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] - if usage_penalty > 0: - self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) - logger.error(f"所有 {max_attempts} 个模型均尝试失败。") if last_exception: raise last_exception