Merge pull request #175 from langchain-ai/nc/interrupt-before-action
Add option to prompt user to continue before executing each tool
nfcampos authored Feb 5, 2024
2 parents bd527f6 + 6683dd7 commit daf9aba
Showing 6 changed files with 88 additions and 15 deletions.
36 changes: 29 additions & 7 deletions backend/app/agent.py
@@ -49,25 +49,38 @@ def get_agent_executor(
     tools: list,
     agent: AgentType,
     system_message: str,
+    interrupt_before_action: bool,
 ):
     if agent == AgentType.GPT_35_TURBO:
         llm = get_openai_llm()
-        return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_openai_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     elif agent == AgentType.GPT_4:
         llm = get_openai_llm(gpt_4=True)
-        return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_openai_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     elif agent == AgentType.AZURE_OPENAI:
         llm = get_openai_llm(azure=True)
-        return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_openai_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     elif agent == AgentType.CLAUDE2:
         llm = get_anthropic_llm()
-        return get_xml_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_xml_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     elif agent == AgentType.BEDROCK_CLAUDE2:
         llm = get_anthropic_llm(bedrock=True)
-        return get_xml_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_xml_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     elif agent == AgentType.GEMINI:
         llm = get_google_llm()
-        return get_google_agent_executor(tools, llm, system_message, CHECKPOINTER)
+        return get_google_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     else:
         raise ValueError("Unexpected agent type")

@@ -77,6 +90,7 @@ class ConfigurableAgent(RunnableBinding):
     agent: AgentType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     retrieval_description: str = RETRIEVAL_DESCRIPTION
+    interrupt_before_action: bool = False
     assistant_id: Optional[str] = None
     user_id: Optional[str] = None

@@ -88,6 +102,7 @@ def __init__(
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         assistant_id: Optional[str] = None,
         retrieval_description: str = RETRIEVAL_DESCRIPTION,
+        interrupt_before_action: bool = False,
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[Mapping[str, Any]] = None,
         **others: Any,
@@ -107,7 +122,9 @@ def __init__(
                     _tools.extend(_returned_tools)
                 else:
                     _tools.append(_returned_tools)
-        _agent = get_agent_executor(_tools, agent, system_message)
+        _agent = get_agent_executor(
+            _tools, agent, system_message, interrupt_before_action
+        )
         agent_executor = _agent.with_config({"recursion_limit": 50})
         super().__init__(
             tools=tools,
@@ -257,6 +274,11 @@ def __init__(
     .configurable_fields(
         agent=ConfigurableField(id="agent_type", name="Agent Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
+        interrupt_before_action=ConfigurableField(
+            id="interrupt_before_action",
+            name="Tool Confirmation",
+            description="If Yes, you'll be prompted to continue before each tool is executed.\nIf No, tools will be executed automatically by the agent.",
+        ),
         assistant_id=ConfigurableField(
             id="assistant_id", name="Assistant ID", is_shared=True
         ),
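Since the flag is exposed as a ConfigurableField with id interrupt_before_action, it becomes part of the agent's configurable schema alongside agent_type and system_message. A hedged sketch of toggling it through LangChain's standard configurable config (the import path, ids, and commented-out invocation are assumptions for illustration, not code from this commit):

```python
# Illustrative only: flipping the new flag through LangChain's `configurable`
# config, using the field id registered above.
from app.agent import agent  # assumed import path for the configurable agent

config = {
    "configurable": {
        "user_id": "user-123",            # mirrors the config shape used in storage.py
        "thread_id": "thread-456",
        "interrupt_before_action": True,  # pause before every tool call
    }
}

# With the flag set, the compiled graph should stop at the tool node and wait
# for the user instead of executing tools automatically, e.g.:
# result = agent.invoke(messages, config)
```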
3 changes: 3 additions & 0 deletions backend/app/agent_types/google_agent.py
@@ -16,6 +16,7 @@ def get_google_agent_executor(
     tools: list[BaseTool],
     llm: LanguageModelLike,
     system_message: str,
+    interrupt_before_action: bool,
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
@@ -105,4 +106,6 @@ async def call_tool(messages):
     # This compiles it into a LangChain Runnable,
     # meaning you can use it as you would any other runnable
     app = workflow.compile(checkpointer=checkpoint)
+    if interrupt_before_action:
+        app.interrupt = ["action:inbox"]
     return app
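The compiled graph here gains an optional interrupt on the action node; the same change appears in the OpenAI and XML executors below. Conceptually, the flag makes a run pause before each tool call and wait for the user to confirm. A framework-agnostic sketch of that control flow (purely illustrative, not LangGraph's API):

```python
# Toy illustration of "interrupt before action" semantics (not LangGraph code):
# the loop stops before each tool call and only proceeds once the user confirms.
from typing import Callable, Optional


def run_agent(
    plan_next_action: Callable[[], Optional[dict]],  # model picks a tool call, or None when finished
    execute_tool: Callable[[dict], None],
    confirm: Callable[[dict], bool],                 # ask the user whether to continue
    interrupt_before_action: bool,
) -> None:
    while True:
        action = plan_next_action()
        if action is None:          # the model produced a final answer
            return
        if interrupt_before_action and not confirm(action):
            return                  # user declined; stop before executing the tool
        execute_tool(action)        # otherwise run the tool and loop again
```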
3 changes: 3 additions & 0 deletions backend/app/agent_types/openai_agent.py
@@ -16,6 +16,7 @@ def get_openai_agent_executor(
     tools: list[BaseTool],
     llm: LanguageModelLike,
     system_message: str,
+    interrupt_before_action: bool,
     checkpoint: BaseCheckpointSaver,
 ):
     async def _get_messages(messages):
@@ -119,4 +120,6 @@ async def call_tool(messages):
     # This compiles it into a LangChain Runnable,
     # meaning you can use it as you would any other runnable
     app = workflow.compile(checkpointer=checkpoint)
+    if interrupt_before_action:
+        app.interrupt = ["action:inbox"]
     return app
3 changes: 3 additions & 0 deletions backend/app/agent_types/xml_agent.py
@@ -63,6 +63,7 @@ def get_xml_agent_executor(
     tools: list[BaseTool],
     llm: LanguageModelLike,
     system_message: str,
+    interrupt_before_action: bool,
     checkpoint: BaseCheckpointSaver,
 ):
     formatted_system_message = xml_template.format(
@@ -153,4 +154,6 @@ async def call_tool(messages):
     # This compiles it into a LangChain Runnable,
     # meaning you can use it as you would any other runnable
     app = workflow.compile(checkpointer=checkpoint)
+    if interrupt_before_action:
+        app.interrupt = ["action:inbox"]
     return app
4 changes: 2 additions & 2 deletions backend/app/storage.py
@@ -144,7 +144,7 @@ def get_thread(user_id: str, thread_id: str) -> Thread | None:
 def get_thread_messages(user_id: str, thread_id: str):
     """Get all messages for a thread."""
     config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
-    app = get_agent_executor([], AgentType.GPT_35_TURBO, "")
+    app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
     checkpoint = app.checkpointer.get(config) or empty_checkpoint()
     with ChannelsManager(app.channels, checkpoint) as channels:
         return {
@@ -158,7 +158,7 @@ def get_thread_messages(user_id: str, thread_id: str):
 def post_thread_messages(user_id: str, thread_id: str, messages: Sequence[AnyMessage]):
     """Add messages to a thread."""
     config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
-    app = get_agent_executor([], AgentType.GPT_35_TURBO, "")
+    app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
     checkpoint = app.checkpointer.get(config) or empty_checkpoint()
     with ChannelsManager(app.channels, checkpoint) as channels:
         channel = channels[MESSAGES_CHANNEL_NAME]
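These storage helpers pass False positionally for the new interrupt_before_action parameter, preserving the old behaviour of executing tools automatically. A minimal sketch of a caller opting in instead (the import path, empty tool list, and system message are placeholders, not code from this repo):

```python
# Illustrative only: building an executor with the confirmation step enabled.
from app.agent import AgentType, get_agent_executor  # assumed import path

app = get_agent_executor(
    [],                              # tools
    AgentType.GPT_35_TURBO,          # agent type
    "You are a helpful assistant.",  # system message
    True,                            # interrupt_before_action: pause before each tool call
)
```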
54 changes: 48 additions & 6 deletions frontend/src/components/Config.tsx
@@ -63,13 +63,18 @@ function Types(props: {
   );
 }
 
-function Label(props: { id?: string; title: string }) {
+function Label(props: { id?: string; title: string; description?: string }) {
   return (
     <label
       htmlFor={props.id}
-      className="block font-medium leading-6 text-gray-400 mb-2"
+      className="flex flex-col font-medium leading-6 text-gray-400 mb-2"
     >
-      {props.title}
+      <div>{props.title}</div>
+      {props.description && (
+        <div className="font-normal text-sm text-gray-600 whitespace-pre-line">
+          {props.description}
+        </div>
+      )}
     </label>
   );
 }
@@ -84,7 +89,11 @@ function StringField(props: {
 }) {
   return (
     <div>
-      <Label id={props.id} title={props.title} />
+      <Label
+        id={props.id}
+        title={props.title}
+        description={props.field.description}
+      />
       <textarea
         rows={4}
         name={props.id}
@@ -109,7 +118,11 @@ export default function SingleOptionField(props: {
 }) {
   return (
     <div>
-      <Label id={props.id} title={props.field.title} />
+      <Label
+        id={props.id}
+        title={props.field.title}
+        description={props.field.description}
+      />
       <fieldset>
         <legend className="sr-only">{props.field.title}</legend>
         <div className="space-y-2">
@@ -170,7 +183,11 @@ function MultiOptionField(props: {
 }) {
   return (
     <fieldset>
-      <Label id={props.id} title={props.title ?? props.field.items?.title} />
+      <Label
+        id={props.id}
+        title={props.title ?? props.field.items?.title}
+        description={props.field.description}
+      />
       <div className="space-y-2">
         {orderBy(props.field.items?.enum)?.map((option) => (
           <div className="relative flex items-start" key={option}>
@@ -285,6 +302,7 @@ function fileId(file: File) {
 const ORDER = [
   "system_message",
   "retrieval_description",
+  "interrupt_before_action",
   "tools",
   "llm_type",
   "agent_type",
@@ -483,6 +501,30 @@ export function Config(props: {
               readonly={readonly}
             />
           );
+        } else if (value.type === "boolean") {
+          return (
+            <SingleOptionField
+              key={key}
+              id={key}
+              field={{
+                ...value,
+                type: "string",
+                enum: ["Yes", "No"],
+              }}
+              title={title}
+              value={values?.configurable?.[key] ? "Yes" : "No"}
+              setValue={(value: string) =>
+                setValues({
+                  ...values,
+                  configurable: {
+                    ...values!.configurable,
+                    [key]: value === "Yes",
+                  },
+                })
+              }
+              readonly={readonly}
+            />
+          );
+        } else if (
+          value.type === "array" &&
+          value.items?.type === "string" &&