Update tool_processor.py

pull/998/head
2829798842 2025-05-28 22:07:23 +08:00 committed by GitHub
parent 10a97f9bf1
commit d998202cc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 30 additions and 19 deletions

View File

@ -12,7 +12,7 @@ from .base_processor import BaseProcessor
from typing import List, Optional, Dict from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.info.structured_info import StructuredInfo from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.chat.focus_chat.info.info_base import InfoBase from src.chat.heart_flow.observation.structure_observation import StructureObservation
logger = get_logger("processor") logger = get_logger("processor")
@ -40,7 +40,6 @@ If you need to use tools, please directly call the corresponding tool functions.
""" """
Prompt(tool_executor_prompt, "tool_executor_prompt") Prompt(tool_executor_prompt, "tool_executor_prompt")
class ToolProcessor(BaseProcessor): class ToolProcessor(BaseProcessor):
log_prefix = "工具执行器" log_prefix = "工具执行器"
@ -51,13 +50,13 @@ class ToolProcessor(BaseProcessor):
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.model.focus_tool_use, model=global_config.model.focus_tool_use,
max_tokens=500, max_tokens=500,
request_type="tool_execution", request_type="focus_tool",
) )
self.structured_info = [] self.structured_info = []
async def process_info( async def process_info(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
) -> List[InfoBase]: ) -> List[dict]:
"""处理信息对象 """处理信息对象
Args: Args:
@ -67,24 +66,36 @@ class ToolProcessor(BaseProcessor):
list: 处理后的结构化信息列表 list: 处理后的结构化信息列表
""" """
result_infos = [] working_infos = []
if observations: if observations:
for observation in observations: for observation in observations:
if isinstance(observation, ChattingObservation): if isinstance(observation, ChattingObservation):
# 执行工具调用 result, used_tools, prompt = await self.execute_tools(observation, running_memorys)
tool_results, used_tools, _ = await self.execute_tools(observation, running_memorys)
# 更新WorkingObservation中的结构化信息
# 为每个工具调用结果创建StructuredInfo对象并返回 logger.debug(f"工具调用结果: {result}")
for tool_result in tool_results:
structured_info = StructuredInfo() for observation in observations:
if isinstance(observation, StructureObservation):
structured_info.data[tool_result.get("type")] = tool_result.get("content") for structured_info in result:
result_infos.append(structured_info) # logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
logger.info(f"{self.log_prefix} 工具调用成功: {tool_result.get('type')} - {tool_result.get('content')}") observation.add_structured_info(structured_info)
logger.debug(f"result_infos: {result_infos}") working_infos = observation.get_observe_info()
return result_infos logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}")
structured_info = StructuredInfo()
if working_infos:
for working_info in working_infos:
# print(f"working_info: {working_info}")
# print(f"working_info.get('type'): {working_info.get('type')}")
# print(f"working_info.get('content'): {working_info.get('content')}")
structured_info.set_info(key=working_info.get("type"), value=working_info.get("content"))
# info = structured_info.get_processed_info()
# print(f"info: {info}")
return [structured_info]
async def execute_tools(self, observation: ChattingObservation, running_memorys: Optional[List[Dict]] = None): async def execute_tools(self, observation: ChattingObservation, running_memorys: Optional[List[Dict]] = None):
""" """
@ -150,7 +161,7 @@ class ToolProcessor(BaseProcessor):
) )
# 调用LLM专注于工具使用 # 调用LLM专注于工具使用
# logger.debug(f"开始执行工具调用{prompt}") logger.debug(f"开始执行工具调用{prompt}")
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools) response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
logger.debug(f"获取到工具原始输出:\n{tool_calls}") logger.debug(f"获取到工具原始输出:\n{tool_calls}")