diff --git a/backend/app/agent_types/openai_agent.py b/backend/app/agent_types/openai_agent.py
index 8ee9a645..c4bf05d0 100644
--- a/backend/app/agent_types/openai_agent.py
+++ b/backend/app/agent_types/openai_agent.py
@@ -1,8 +1,8 @@
 import json
 
-from langchain.schema.messages import FunctionMessage
+from langchain.schema.messages import ToolMessage
 from langchain.tools import BaseTool
-from langchain.tools.render import format_tool_to_openai_function
+from langchain.tools.render import format_tool_to_openai_tool
 from langchain_core.language_models.base import LanguageModelLike
 from langchain_core.messages import SystemMessage
 from langgraph.checkpoint import BaseCheckpointSaver
@@ -21,9 +21,7 @@ def _get_messages(messages):
         return [SystemMessage(content=system_message)] + messages
 
     if tools:
-        llm_with_tools = llm.bind(
-            functions=[format_tool_to_openai_function(t) for t in tools]
-        )
+        llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
     else:
         llm_with_tools = llm
     agent = _get_messages | llm_with_tools
@@ -33,7 +31,7 @@
     def should_continue(messages):
         last_message = messages[-1]
         # If there is no function call, then we finish
-        if "function_call" not in last_message.additional_kwargs:
+        if "tool_calls" not in last_message.additional_kwargs:
             return "end"
         # Otherwise if there is, we continue
         else:
@@ -41,22 +39,35 @@
 
     # Define the function to execute tools
     async def call_tool(messages):
+        actions: list[ToolInvocation] = []
         # Based on the continue condition
         # we know the last message involves a function call
         last_message = messages[-1]
-        # We construct an ToolInvocation from the function_call
-        action = ToolInvocation(
-            tool=last_message.additional_kwargs["function_call"]["name"],
-            tool_input=json.loads(
-                last_message.additional_kwargs["function_call"]["arguments"]
-            ),
-        )
+        for tool_call in last_message.additional_kwargs["tool_calls"]:
+            function = tool_call["function"]
+            function_name = function["name"]
+            _tool_input = json.loads(function["arguments"] or "{}")
+            # We construct an ToolInvocation from the function_call
+            actions.append(
+                ToolInvocation(
+                    tool=function_name,
+                    tool_input=_tool_input,
+                )
+            )
         # We call the tool_executor and get back a response
-        response = await tool_executor.ainvoke(action)
-        # We use the response to create a FunctionMessage
-        function_message = FunctionMessage(content=str(response), name=action.tool)
-        # We return a list, because this will get added to the existing list
-        return function_message
+        responses = await tool_executor.abatch(actions)
+        # We use the response to create a ToolMessage
+        tool_messages = [
+            ToolMessage(
+                tool_call_id=tool_call["id"],
+                content=json.dumps(response),
+                additional_kwargs={"name": tool_call["function"]["name"]},
+            )
+            for tool_call, response in zip(
+                last_message.additional_kwargs["tool_calls"], responses
+            )
+        ]
+        return tool_messages
 
     workflow = MessageGraph()
 
diff --git a/frontend/src/components/Message.tsx b/frontend/src/components/Message.tsx
index 9686e9c2..2df9ec9d 100644
--- a/frontend/src/components/Message.tsx
+++ b/frontend/src/components/Message.tsx
@@ -52,7 +52,7 @@ function Function(props: {
         )}
       {props.args && (
-