diff --git a/backend/app/agent.py b/backend/app/agent.py index 7b75b88c..d9e885ce 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -49,25 +49,38 @@ def get_agent_executor( tools: list, agent: AgentType, system_message: str, + interrupt_before_action: bool, ): if agent == AgentType.GPT_35_TURBO: llm = get_openai_llm() - return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_openai_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) elif agent == AgentType.GPT_4: llm = get_openai_llm(gpt_4=True) - return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_openai_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) elif agent == AgentType.AZURE_OPENAI: llm = get_openai_llm(azure=True) - return get_openai_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_openai_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) elif agent == AgentType.CLAUDE2: llm = get_anthropic_llm() - return get_xml_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_xml_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) elif agent == AgentType.BEDROCK_CLAUDE2: llm = get_anthropic_llm(bedrock=True) - return get_xml_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_xml_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) elif agent == AgentType.GEMINI: llm = get_google_llm() - return get_google_agent_executor(tools, llm, system_message, CHECKPOINTER) + return get_google_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) else: raise ValueError("Unexpected agent type") @@ -77,6 +90,7 @@ class ConfigurableAgent(RunnableBinding): agent: AgentType system_message: str = DEFAULT_SYSTEM_MESSAGE retrieval_description: str = RETRIEVAL_DESCRIPTION + interrupt_before_action: bool = False assistant_id: Optional[str] = None user_id: Optional[str] = None @@ -88,6 +102,7 @@ def __init__( system_message: str = DEFAULT_SYSTEM_MESSAGE, assistant_id: Optional[str] = None, retrieval_description: str = RETRIEVAL_DESCRIPTION, + interrupt_before_action: bool = False, kwargs: Optional[Mapping[str, Any]] = None, config: Optional[Mapping[str, Any]] = None, **others: Any, @@ -107,7 +122,9 @@ def __init__( _tools.extend(_returned_tools) else: _tools.append(_returned_tools) - _agent = get_agent_executor(_tools, agent, system_message) + _agent = get_agent_executor( + _tools, agent, system_message, interrupt_before_action + ) agent_executor = _agent.with_config({"recursion_limit": 50}) super().__init__( tools=tools, @@ -257,6 +274,11 @@ def __init__( .configurable_fields( agent=ConfigurableField(id="agent_type", name="Agent Type"), system_message=ConfigurableField(id="system_message", name="Instructions"), + interrupt_before_action=ConfigurableField( + id="interrupt_before_action", + name="Tool Confirmation", + description="If Yes, you'll be prompted to continue before each tool is executed.\nIf No, tools will be executed automatically by the agent.", + ), assistant_id=ConfigurableField( id="assistant_id", name="Assistant ID", is_shared=True ), diff --git a/backend/app/agent_types/google_agent.py b/backend/app/agent_types/google_agent.py index 499dfbf2..3ba2c068 100644 --- a/backend/app/agent_types/google_agent.py +++ b/backend/app/agent_types/google_agent.py @@ -16,6 +16,7 @@ def get_google_agent_executor( tools: list[BaseTool], llm: LanguageModelLike, system_message: str, + interrupt_before_action: bool, checkpoint: BaseCheckpointSaver, ): def _get_messages(messages): @@ -105,4 +106,6 @@ async def call_tool(messages): # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile(checkpointer=checkpoint) + if interrupt_before_action: + app.interrupt = ["action:inbox"] return app diff --git a/backend/app/agent_types/openai_agent.py b/backend/app/agent_types/openai_agent.py index 643c10c1..5de2ac8a 100644 --- a/backend/app/agent_types/openai_agent.py +++ b/backend/app/agent_types/openai_agent.py @@ -16,6 +16,7 @@ def get_openai_agent_executor( tools: list[BaseTool], llm: LanguageModelLike, system_message: str, + interrupt_before_action: bool, checkpoint: BaseCheckpointSaver, ): async def _get_messages(messages): @@ -119,4 +120,6 @@ async def call_tool(messages): # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile(checkpointer=checkpoint) + if interrupt_before_action: + app.interrupt = ["action:inbox"] return app diff --git a/backend/app/agent_types/xml_agent.py b/backend/app/agent_types/xml_agent.py index 6e194f76..256bcb42 100644 --- a/backend/app/agent_types/xml_agent.py +++ b/backend/app/agent_types/xml_agent.py @@ -63,6 +63,7 @@ def get_xml_agent_executor( tools: list[BaseTool], llm: LanguageModelLike, system_message: str, + interrupt_before_action: bool, checkpoint: BaseCheckpointSaver, ): formatted_system_message = xml_template.format( @@ -153,4 +154,6 @@ async def call_tool(messages): # This compiles it into a LangChain Runnable, # meaning you can use it as you would any other runnable app = workflow.compile(checkpointer=checkpoint) + if interrupt_before_action: + app.interrupt = ["action:inbox"] return app diff --git a/backend/app/storage.py b/backend/app/storage.py index c0a9d9bb..a8b1a620 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -144,7 +144,7 @@ def get_thread(user_id: str, thread_id: str) -> Thread | None: def get_thread_messages(user_id: str, thread_id: str): """Get all messages for a thread.""" config = {"configurable": {"user_id": user_id, "thread_id": thread_id}} - app = get_agent_executor([], AgentType.GPT_35_TURBO, "") + app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False) checkpoint = app.checkpointer.get(config) or empty_checkpoint() with ChannelsManager(app.channels, checkpoint) as channels: return { @@ -158,7 +158,7 @@ def get_thread_messages(user_id: str, thread_id: str): def post_thread_messages(user_id: str, thread_id: str, messages: Sequence[AnyMessage]): """Add messages to a thread.""" config = {"configurable": {"user_id": user_id, "thread_id": thread_id}} - app = get_agent_executor([], AgentType.GPT_35_TURBO, "") + app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False) checkpoint = app.checkpointer.get(config) or empty_checkpoint() with ChannelsManager(app.channels, checkpoint) as channels: channel = channels[MESSAGES_CHANNEL_NAME] diff --git a/frontend/src/components/Config.tsx b/frontend/src/components/Config.tsx index 36f40095..b3d25862 100644 --- a/frontend/src/components/Config.tsx +++ b/frontend/src/components/Config.tsx @@ -63,13 +63,18 @@ function Types(props: { ); } -function Label(props: { id?: string; title: string }) { +function Label(props: { id?: string; title: string; description?: string }) { return ( ); } @@ -84,7 +89,11 @@ function StringField(props: { }) { return (