Skip to content

Commit

Permalink
Use abatch
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jan 29, 2024
1 parent 2bdc7c7 commit 2cdf879
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions backend/app/agent_types/openai_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from langchain.schema.messages import FunctionMessage, ToolMessage
from langchain.schema.messages import ToolMessage
from langchain.tools import BaseTool
from langchain.tools.render import format_tool_to_openai_tool
from langchain_core.language_models.base import LanguageModelLike
Expand Down Expand Up @@ -39,7 +39,7 @@ def should_continue(messages):

# Define the function to execute tools
async def call_tool(messages):
tool_messages = []
actions: list[ToolInvocation] = []
# Based on the continue condition
# we know the last message involves a function call
last_message = messages[-1]
Expand All @@ -48,19 +48,25 @@ async def call_tool(messages):
function_name = function["name"]
_tool_input = json.loads(function["arguments"] or "{}")
# We construct an ToolInvocation from the function_call
action = ToolInvocation(
tool=function_name,
tool_input=_tool_input,
actions.append(
ToolInvocation(
tool=function_name,
tool_input=_tool_input,
)
)
# We call the tool_executor and get back a response
response = await tool_executor.ainvoke(action)
# We use the response to create a FunctionMessage
msg = ToolMessage(
# We call the tool_executor and get back a response
responses = await tool_executor.abatch(actions)
# We use the response to create a ToolMessage
tool_messages = [
ToolMessage(
tool_call_id=tool_call["id"],
content=json.dumps(response),
additional_kwargs={"name": function_name},
additional_kwargs={"name": tool_call["function"]["name"]},
)
tool_messages.append(msg)
for tool_call, response in zip(
last_message.additional_kwargs["tool_calls"], responses
)
]
return tool_messages

workflow = MessageGraph()
Expand Down

0 comments on commit 2cdf879

Please sign in to comment.