From e816a4ab4c524f3614a4fda8929ff72451d083da Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 13 Nov 2025 19:03:51 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9Aruff?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_system/memory_retrieval.py | 20 ++++++++++++------- .../retrieval_tools/tool_registry.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index ab4e0f5b..173ad197 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -10,7 +10,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.plugin_system.apis import llm_api from src.common.database.database_model import ThinkingBack from json_repair import repair_json -from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools, register_memory_retrieval_tool +from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message logger = get_logger("memory_retrieval") @@ -384,17 +384,23 @@ async def _react_agent_solve_question( remaining_iterations=remaining_iterations, ) - def message_factory(_client) -> List[Message]: + def message_factory( + _client, + *, + _head_prompt: str = head_prompt, + _prompt: str = prompt, + _conversation_messages: List[Message] = conversation_messages, + ) -> List[Message]: messages: List[Message] = [] system_builder = MessageBuilder() system_builder.set_role(RoleType.System) - system_builder.add_text_content(head_prompt) - if prompt.strip(): - system_builder.add_text_content(f"\n{prompt}") + system_builder.add_text_content(_head_prompt) + if _prompt.strip(): + system_builder.add_text_content(f"\n{_prompt}") messages.append(system_builder.build()) - messages.extend(conversation_messages) + messages.extend(_conversation_messages) for msg in messages: print(msg) @@ -605,7 +611,7 @@ async def _react_agent_solve_question( observations = await asyncio.gather(*tool_tasks, return_exceptions=True) # 处理执行结果 - for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations)): + for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)): if isinstance(observation, Exception): observation = f"工具执行异常: {str(observation)}" logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}") diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py index 1bf889ec..143666ab 100644 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ b/src/memory_system/retrieval_tools/tool_registry.py @@ -3,7 +3,7 @@ 提供统一的工具注册和管理接口 """ -from typing import List, Dict, Any, Optional, Callable, Awaitable, Tuple +from typing import List, Dict, Any, Optional, Callable, Awaitable from src.common.logger import get_logger from src.llm_models.payload_content.tool_option import ToolParamType