-
Notifications
You must be signed in to change notification settings - Fork 859
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #132 from langchain-ai/nc/jan18/update-langgraph
Update to use langgraph and langchain 0.1
- Loading branch information
Showing
42 changed files
with
1,642 additions
and
6,545 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from enum import Enum | ||
from typing import Any, Mapping, Optional, Sequence | ||
|
||
from langchain_core.messages import AnyMessage | ||
from langchain_core.runnables import ( | ||
ConfigurableField, | ||
ConfigurableFieldMultiOption, | ||
RunnableBinding, | ||
) | ||
|
||
from app.agent_types.google_agent import get_google_agent_executor | ||
from app.agent_types.openai_agent import get_openai_agent_executor | ||
from app.agent_types.xml_agent import get_xml_agent_executor | ||
from app.checkpoint import RedisCheckpoint | ||
from app.llms import get_anthropic_llm, get_google_llm, get_openai_llm | ||
from app.tools import ( | ||
RETRIEVAL_DESCRIPTION, | ||
TOOL_OPTIONS, | ||
TOOLS, | ||
AvailableTools, | ||
get_retrieval_tool, | ||
) | ||
|
||
|
||
class AgentType(str, Enum): | ||
GPT_35_TURBO = "GPT 3.5 Turbo" | ||
GPT_4 = "GPT 4" | ||
AZURE_OPENAI = "GPT 4 (Azure OpenAI)" | ||
CLAUDE2 = "Claude 2" | ||
BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)" | ||
GEMINI = "GEMINI" | ||
|
||
|
||
DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." | ||
|
||
|
||
def get_agent_executor( | ||
tools: list, | ||
agent: AgentType, | ||
system_message: str, | ||
): | ||
checkpointer = RedisCheckpoint() | ||
if agent == AgentType.GPT_35_TURBO: | ||
llm = get_openai_llm() | ||
return get_openai_agent_executor(tools, llm, system_message, checkpointer) | ||
elif agent == AgentType.GPT_4: | ||
llm = get_openai_llm(gpt_4=True) | ||
return get_openai_agent_executor(tools, llm, system_message, checkpointer) | ||
elif agent == AgentType.AZURE_OPENAI: | ||
llm = get_openai_llm(azure=True) | ||
return get_openai_agent_executor(tools, llm, system_message, checkpointer) | ||
elif agent == AgentType.CLAUDE2: | ||
llm = get_anthropic_llm() | ||
return get_xml_agent_executor(tools, llm, system_message, checkpointer) | ||
elif agent == AgentType.BEDROCK_CLAUDE2: | ||
llm = get_anthropic_llm(bedrock=True) | ||
return get_xml_agent_executor(tools, llm, system_message, checkpointer) | ||
elif agent == AgentType.GEMINI: | ||
llm = get_google_llm() | ||
return get_google_agent_executor(tools, llm, system_message, checkpointer) | ||
else: | ||
raise ValueError("Unexpected agent type") | ||
|
||
|
||
class ConfigurableAgent(RunnableBinding): | ||
tools: Sequence[str] | ||
agent: AgentType | ||
system_message: str = DEFAULT_SYSTEM_MESSAGE | ||
retrieval_description: str = RETRIEVAL_DESCRIPTION | ||
assistant_id: Optional[str] = None | ||
user_id: Optional[str] = None | ||
|
||
def __init__( | ||
self, | ||
*, | ||
tools: Sequence[str], | ||
agent: AgentType = AgentType.GPT_35_TURBO, | ||
system_message: str = DEFAULT_SYSTEM_MESSAGE, | ||
assistant_id: Optional[str] = None, | ||
retrieval_description: str = RETRIEVAL_DESCRIPTION, | ||
kwargs: Optional[Mapping[str, Any]] = None, | ||
config: Optional[Mapping[str, Any]] = None, | ||
**others: Any, | ||
) -> None: | ||
others.pop("bound", None) | ||
_tools = [] | ||
for _tool in tools: | ||
if _tool == AvailableTools.RETRIEVAL: | ||
if assistant_id is None: | ||
raise ValueError( | ||
"assistant_id must be provided if Retrieval tool is used" | ||
) | ||
_tools.append(get_retrieval_tool(assistant_id, retrieval_description)) | ||
else: | ||
_tools.append(TOOLS[_tool]()) | ||
_agent = get_agent_executor(_tools, agent, system_message) | ||
agent_executor = _agent.with_config({"recursion_limit": 50}) | ||
super().__init__( | ||
tools=tools, | ||
agent=agent, | ||
system_message=system_message, | ||
retrieval_description=retrieval_description, | ||
bound=agent_executor, | ||
kwargs=kwargs or {}, | ||
config=config or {}, | ||
) | ||
|
||
|
||
agent = ( | ||
ConfigurableAgent( | ||
agent=AgentType.GPT_35_TURBO, | ||
tools=[], | ||
system_message=DEFAULT_SYSTEM_MESSAGE, | ||
retrieval_description=RETRIEVAL_DESCRIPTION, | ||
assistant_id=None, | ||
) | ||
.configurable_fields( | ||
agent=ConfigurableField(id="agent_type", name="Agent Type"), | ||
system_message=ConfigurableField(id="system_message", name="System Message"), | ||
assistant_id=ConfigurableField( | ||
id="assistant_id", name="Assistant ID", is_shared=True | ||
), | ||
tools=ConfigurableFieldMultiOption( | ||
id="tools", | ||
name="Tools", | ||
options=TOOL_OPTIONS, | ||
default=[], | ||
), | ||
retrieval_description=ConfigurableField( | ||
id="retrieval_description", name="Retrieval Description" | ||
), | ||
) | ||
.with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage]) | ||
) | ||
|
||
if __name__ == "__main__": | ||
import asyncio | ||
|
||
from langchain.schema.messages import HumanMessage | ||
|
||
async def run(): | ||
async for m in agent.astream_events( | ||
HumanMessage(content="whats your name"), | ||
config={"configurable": {"user_id": "2", "thread_id": "test1"}}, | ||
version="v1", | ||
): | ||
print(m) | ||
|
||
asyncio.run(run()) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import json | ||
|
||
from langchain.schema.messages import FunctionMessage | ||
from langchain.tools import BaseTool | ||
from langchain_core.language_models.base import LanguageModelLike | ||
from langchain_core.messages import SystemMessage | ||
from langgraph.checkpoint import BaseCheckpointSaver | ||
from langgraph.graph import END | ||
from langgraph.graph.message import MessageGraph | ||
from langgraph.prebuilt import ToolExecutor, ToolInvocation | ||
|
||
|
||
def get_google_agent_executor( | ||
tools: list[BaseTool], | ||
llm: LanguageModelLike, | ||
system_message: str, | ||
checkpoint: BaseCheckpointSaver, | ||
): | ||
def _get_messages(messages): | ||
return [SystemMessage(content=system_message)] + messages | ||
|
||
if tools: | ||
llm_with_tools = llm.bind(functions=tools) | ||
else: | ||
llm_with_tools = llm | ||
agent = _get_messages | llm_with_tools | ||
tool_executor = ToolExecutor(tools) | ||
|
||
# Define the function that determines whether to continue or not | ||
def should_continue(messages): | ||
last_message = messages[-1] | ||
# If there is no function call, then we finish | ||
if "function_call" not in last_message.additional_kwargs: | ||
return "end" | ||
# Otherwise if there is, we continue | ||
else: | ||
return "continue" | ||
|
||
# Define the function to execute tools | ||
async def call_tool(messages): | ||
# Based on the continue condition | ||
# we know the last message involves a function call | ||
last_message = messages[-1] | ||
# We construct an ToolInvocation from the function_call | ||
action = ToolInvocation( | ||
tool=last_message.additional_kwargs["function_call"]["name"], | ||
tool_input=json.loads( | ||
last_message.additional_kwargs["function_call"]["arguments"] | ||
), | ||
) | ||
# We call the tool_executor and get back a response | ||
response = await tool_executor.ainvoke(action) | ||
# We use the response to create a FunctionMessage | ||
function_message = FunctionMessage(content=str(response), name=action.tool) | ||
# We return a list, because this will get added to the existing list | ||
return function_message | ||
|
||
workflow = MessageGraph() | ||
|
||
# Define the two nodes we will cycle between | ||
workflow.add_node("agent", agent) | ||
workflow.add_node("action", call_tool) | ||
|
||
# Set the entrypoint as `agent` | ||
# This means that this node is the first one called | ||
workflow.set_entry_point("agent") | ||
|
||
# We now add a conditional edge | ||
workflow.add_conditional_edges( | ||
# First, we define the start node. We use `agent`. | ||
# This means these are the edges taken after the `agent` node is called. | ||
"agent", | ||
# Next, we pass in the function that will determine which node is called next. | ||
should_continue, | ||
# Finally we pass in a mapping. | ||
# The keys are strings, and the values are other nodes. | ||
# END is a special node marking that the graph should finish. | ||
# What will happen is we will call `should_continue`, and then the output of that | ||
# will be matched against the keys in this mapping. | ||
# Based on which one it matches, that node will then be called. | ||
{ | ||
# If `tools`, then we call the tool node. | ||
"continue": "action", | ||
# Otherwise we finish. | ||
"end": END, | ||
}, | ||
) | ||
|
||
# We now add a normal edge from `tools` to `agent`. | ||
# This means that after `tools` is called, `agent` node is called next. | ||
workflow.add_edge("action", "agent") | ||
|
||
# Finally, we compile it! | ||
# This compiles it into a LangChain Runnable, | ||
# meaning you can use it as you would any other runnable | ||
app = workflow.compile(checkpointer=checkpoint) | ||
return app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import json | ||
|
||
from langchain.schema.messages import FunctionMessage | ||
from langchain.tools import BaseTool | ||
from langchain.tools.render import format_tool_to_openai_function | ||
from langchain_core.language_models.base import LanguageModelLike | ||
from langchain_core.messages import SystemMessage | ||
from langgraph.checkpoint import BaseCheckpointSaver | ||
from langgraph.graph import END | ||
from langgraph.graph.message import MessageGraph | ||
from langgraph.prebuilt import ToolExecutor, ToolInvocation | ||
|
||
|
||
def get_openai_agent_executor( | ||
tools: list[BaseTool], | ||
llm: LanguageModelLike, | ||
system_message: str, | ||
checkpoint: BaseCheckpointSaver, | ||
): | ||
def _get_messages(messages): | ||
return [SystemMessage(content=system_message)] + messages | ||
|
||
if tools: | ||
llm_with_tools = llm.bind( | ||
functions=[format_tool_to_openai_function(t) for t in tools] | ||
) | ||
else: | ||
llm_with_tools = llm | ||
agent = _get_messages | llm_with_tools | ||
tool_executor = ToolExecutor(tools) | ||
|
||
# Define the function that determines whether to continue or not | ||
def should_continue(messages): | ||
last_message = messages[-1] | ||
# If there is no function call, then we finish | ||
if "function_call" not in last_message.additional_kwargs: | ||
return "end" | ||
# Otherwise if there is, we continue | ||
else: | ||
return "continue" | ||
|
||
# Define the function to execute tools | ||
async def call_tool(messages): | ||
# Based on the continue condition | ||
# we know the last message involves a function call | ||
last_message = messages[-1] | ||
# We construct an ToolInvocation from the function_call | ||
action = ToolInvocation( | ||
tool=last_message.additional_kwargs["function_call"]["name"], | ||
tool_input=json.loads( | ||
last_message.additional_kwargs["function_call"]["arguments"] | ||
), | ||
) | ||
# We call the tool_executor and get back a response | ||
response = await tool_executor.ainvoke(action) | ||
# We use the response to create a FunctionMessage | ||
function_message = FunctionMessage(content=str(response), name=action.tool) | ||
# We return a list, because this will get added to the existing list | ||
return function_message | ||
|
||
workflow = MessageGraph() | ||
|
||
# Define the two nodes we will cycle between | ||
workflow.add_node("agent", agent) | ||
workflow.add_node("action", call_tool) | ||
|
||
# Set the entrypoint as `agent` | ||
# This means that this node is the first one called | ||
workflow.set_entry_point("agent") | ||
|
||
# We now add a conditional edge | ||
workflow.add_conditional_edges( | ||
# First, we define the start node. We use `agent`. | ||
# This means these are the edges taken after the `agent` node is called. | ||
"agent", | ||
# Next, we pass in the function that will determine which node is called next. | ||
should_continue, | ||
# Finally we pass in a mapping. | ||
# The keys are strings, and the values are other nodes. | ||
# END is a special node marking that the graph should finish. | ||
# What will happen is we will call `should_continue`, and then the output of that | ||
# will be matched against the keys in this mapping. | ||
# Based on which one it matches, that node will then be called. | ||
{ | ||
# If `tools`, then we call the tool node. | ||
"continue": "action", | ||
# Otherwise we finish. | ||
"end": END, | ||
}, | ||
) | ||
|
||
# We now add a normal edge from `tools` to `agent`. | ||
# This means that after `tools` is called, `agent` node is called next. | ||
workflow.add_edge("action", "agent") | ||
|
||
# Finally, we compile it! | ||
# This compiles it into a LangChain Runnable, | ||
# meaning you can use it as you would any other runnable | ||
app = workflow.compile(checkpointer=checkpoint) | ||
return app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
xml_template = """{system_message} | ||
You have access to the following tools: | ||
{tools} | ||
In order to use a tool, you can use <tool></tool> and <tool_input></tool_input> tags. You will then get back a response in the form <observation></observation> | ||
For example, if you have a tool called 'search' that could run a google search, in order to search for the weather in SF you would respond: | ||
<tool>search</tool><tool_input>weather in SF</tool_input> | ||
<observation>64 degrees</observation> | ||
When you are done, you can respond as normal to the user. | ||
Example 1: | ||
Human: Hi! | ||
Assistant: Hi! How are you? | ||
Human: What is the weather in SF? | ||
Assistant: <tool>search</tool><tool_input>weather in SF</tool_input> | ||
<observation>64 degrees</observation> | ||
It is 64 degrees in SF | ||
Begin!""" |
Oops, something went wrong.