Merge pull request #304 from langchain-ai/nc/15apr/tool-calls
Upgrade to new tool-calling API
nfcampos authored Apr 15, 2024
2 parents 950719c + 5e5e92b commit cb39b9b
Showing 10 changed files with 209 additions and 263 deletions.
15 changes: 7 additions & 8 deletions backend/app/agent.py
@@ -8,8 +8,7 @@
 )
 from langgraph.checkpoint import CheckpointAt

-from app.agent_types.google_agent import get_google_agent_executor
-from app.agent_types.openai_agent import get_openai_agent_executor
+from app.agent_types.tools_agent import get_tools_agent_executor
 from app.agent_types.xml_agent import get_xml_agent_executor
 from app.chatbot import get_chatbot_executor
 from app.checkpoint import PostgresCheckpoint
@@ -80,22 +79,22 @@ def get_agent_executor(
 ):
     if agent == AgentType.GPT_35_TURBO:
         llm = get_openai_llm()
-        return get_openai_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.GPT_4:
         llm = get_openai_llm(gpt_4=True)
-        return get_openai_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.AZURE_OPENAI:
         llm = get_openai_llm(azure=True)
-        return get_openai_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.CLAUDE2:
         llm = get_anthropic_llm()
-        return get_xml_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.BEDROCK_CLAUDE2:
@@ -105,12 +104,12 @@ def get_agent_executor(
         )
     elif agent == AgentType.GEMINI:
         llm = get_google_llm()
-        return get_google_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.OLLAMA:
         llm = get_ollama_llm()
-        return get_openai_agent_executor(
+        return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )

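Every provider branch now funnels into the same get_tools_agent_executor. What makes this possible is that the chat-model classes share a common bind_tools() interface; a minimal sketch (the model names here are illustrative, and `tools` is assumed to be a list of BaseTool):

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

# Both classes expose the same .bind_tools() method, which is what lets one
# provider-agnostic executor replace the per-provider executors above.
openai_with_tools = ChatOpenAI(model="gpt-3.5-turbo").bind_tools(tools)
claude_with_tools = ChatAnthropic(model_name="claude-3-haiku-20240307").bind_tools(tools)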
111 changes: 0 additions & 111 deletions backend/app/agent_types/google_agent.py

This file was deleted.

backend/app/agent_types/{openai_agent.py → tools_agent.py}
@@ -1,9 +1,14 @@
-import json
+from typing import cast

 from langchain.tools import BaseTool
-from langchain.tools.render import format_tool_to_openai_tool
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import SystemMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
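Dropping import json mirrors a change further down: the standardized tool_calls attribute exposes arguments as already-parsed dicts, whereas the old additional_kwargs payload carried them as a raw JSON string that the agent had to json.loads itself.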
@@ -12,7 +17,7 @@
 from app.message_types import LiberalToolMessage


-def get_openai_agent_executor(
+def get_tools_agent_executor(
     tools: list[BaseTool],
     llm: LanguageModelLike,
     system_message: str,
@@ -27,13 +32,16 @@ async def _get_messages(messages):
_dict["content"] = str(_dict["content"])
m_c = ToolMessage(**_dict)
msgs.append(m_c)
elif isinstance(m, FunctionMessage):
# anthropic doesn't like function messages
msgs.append(HumanMessage(content=str(m.content)))
else:
msgs.append(m)

return [SystemMessage(content=system_message)] + msgs

if tools:
llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
agent = _get_messages | llm_with_tools
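bind_tools() accepts BaseTool instances directly and lets each provider integration render them into its own wire format, replacing the OpenAI-specific format_tool_to_openai_tool conversion. A sketch with a hypothetical tool:

from langchain_core.tools import tool

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

# The provider integration translates the tool into its native schema
# (an OpenAI function definition, an Anthropic tool-use block, etc.).
llm_with_tools = llm.bind_tools([multiply])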
@@ -43,7 +51,7 @@ async def _get_messages(messages):
     def should_continue(messages):
         last_message = messages[-1]
         # If there is no function call, then we finish
-        if "tool_calls" not in last_message.additional_kwargs:
+        if not last_message.tool_calls:
             return "end"
         # Otherwise if there is, we continue
         else:
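The check works because tool_calls is a standardized attribute on AIMessage: a list of dicts with name, args, and id keys, identical across providers, and simply empty when the model made no tool call. Illustrative (hypothetical) values:

AIMessage(
    content="",
    tool_calls=[{"name": "multiply", "args": {"a": 2, "b": 3}, "id": "call_abc123"}],
)
# should_continue returns "continue" for this message; with tool_calls == []
# it returns "end".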
@@ -54,16 +62,13 @@ async def call_tool(messages):
         actions: list[ToolInvocation] = []
         # Based on the continue condition
         # we know the last message involves a function call
-        last_message = messages[-1]
-        for tool_call in last_message.additional_kwargs["tool_calls"]:
-            function = tool_call["function"]
-            function_name = function["name"]
-            _tool_input = json.loads(function["arguments"] or "{}")
-            # We construct an ToolInvocation from the function_call
+        last_message = cast(AIMessage, messages[-1])
+        for tool_call in last_message.tool_calls:
+            # We construct a ToolInvocation from the function_call
             actions.append(
                 ToolInvocation(
-                    tool=function_name,
-                    tool_input=_tool_input,
+                    tool=tool_call["name"],
+                    tool_input=tool_call["args"],
                 )
             )
         # We call the tool_executor and get back a response
@@ -72,12 +77,10 @@ async def call_tool(messages):
         tool_messages = [
             LiberalToolMessage(
                 tool_call_id=tool_call["id"],
+                name=tool_call["name"],
                 content=response,
-                additional_kwargs={"name": tool_call["function"]["name"]},
             )
-            for tool_call, response in zip(
-                last_message.additional_kwargs["tool_calls"], responses
-            )
+            for tool_call, response in zip(last_message.tool_calls, responses)
         ]
         return tool_messages

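Each response is wrapped in a LiberalToolMessage whose tool_call_id echoes the id of the originating tool call; that pairing is how the model matches results to requests on the next turn. With the hypothetical call above, the result would look like:

LiberalToolMessage(
    tool_call_id="call_abc123",  # must match the AIMessage tool call's id
    name="multiply",
    content="6",
)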
15 changes: 12 additions & 3 deletions backend/app/checkpoint.py
@@ -2,13 +2,22 @@
 from datetime import datetime
 from typing import AsyncIterator, Optional

+from langchain_core.messages import BaseMessage
 from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple

 from app.lifespan import get_pg_pool


+def loads(value: bytes) -> Checkpoint:
+    loaded: Checkpoint = pickle.loads(value)
+    for key, value in loaded["channel_values"].items():
+        if isinstance(value, list) and all(isinstance(v, BaseMessage) for v in value):
+            loaded["channel_values"][key] = [v.__class__(**v.__dict__) for v in value]
+    return loaded
+
+
 class PostgresCheckpoint(BaseCheckpointSaver):
     class Config:
         arbitrary_types_allowed = True
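The new loads() wrapper exists because checkpoints are pickled: pickle restores each message's __dict__ verbatim, so messages saved before this upgrade can lack newly introduced fields such as AIMessage.tool_calls. Re-invoking the constructor re-validates against the current class definition. A sketch of the idea (raw_bytes is a hypothetical pickled message):

stale = pickle.loads(raw_bytes)            # may predate newer fields
fresh = stale.__class__(**stale.__dict__)  # re-validated; new defaults filled in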
@@ -47,7 +56,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
"thread_ts": value[1],
}
},
pickle.loads(value[0]),
loads(value[0]),
{
"configurable": {
"thread_id": thread_id,
@@ -70,7 +79,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
             ):
                 return CheckpointTuple(
                     config,
-                    pickle.loads(value[0]),
+                    loads(value[0]),
                     {
                         "configurable": {
                             "thread_id": thread_id,
@@ -92,7 +101,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"thread_ts": value[1],
}
},
pickle.loads(value[0]),
loads(value[0]),
{
"configurable": {
"thread_id": thread_id,
9 changes: 7 additions & 2 deletions backend/app/llms.py
@@ -5,10 +5,11 @@

 import boto3
 import httpx
-from langchain_community.chat_models import BedrockChat, ChatAnthropic, ChatFireworks
+from langchain_community.chat_models import BedrockChat, ChatFireworks
 from langchain_community.chat_models.ollama import ChatOllama
 from langchain_google_vertexai import ChatVertexAI
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_anthropic import ChatAnthropic

 logger = logging.getLogger(__name__)

@@ -67,7 +68,11 @@ def get_anthropic_llm(bedrock: bool = False):
         )
         model = BedrockChat(model_id="anthropic.claude-v2", client=client)
     else:
-        model = ChatAnthropic(temperature=0, max_tokens_to_sample=2000)
+        model = ChatAnthropic(
+            model_name="claude-3-haiku-20240307",
+            max_tokens_to_sample=2000,
+            temperature=0,
+        )
     return model
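The Anthropic model moves from langchain_community's claude-2 default to the dedicated langchain_anthropic package, pinned to a Claude 3 model, which supports the newer native tool-use interface that bind_tools targets. A sketch of the resulting path (the tools list is assumed):

llm = get_anthropic_llm()
llm_with_tools = llm.bind_tools(tools)  # native Anthropic tool use, no XML agent needed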


(Remaining changed files not shown.)
