From 2fe719eb04f9de1431346f2b43c5eb5e6a163c29 Mon Sep 17 00:00:00 2001 From: Bakar Tavadze Date: Mon, 27 May 2024 18:03:08 +0400 Subject: [PATCH 1/4] Add a node to all cognitive architectures that uniformly signifies the finish point of the graph workflow. --- backend/app/agent_types/constants.py | 2 ++ backend/app/agent_types/tools_agent.py | 6 ++++-- backend/app/chatbot.py | 5 ++++- backend/app/retrieval.py | 6 ++++-- 4 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 backend/app/agent_types/constants.py diff --git a/backend/app/agent_types/constants.py b/backend/app/agent_types/constants.py new file mode 100644 index 00000000..7032441f --- /dev/null +++ b/backend/app/agent_types/constants.py @@ -0,0 +1,2 @@ +FINISH_NODE_KEY = "finish" +FINISH_NODE_ACTION = lambda _: None # noqa: E731 diff --git a/backend/app/agent_types/tools_agent.py b/backend/app/agent_types/tools_agent.py index 0a061af1..1775e07a 100644 --- a/backend/app/agent_types/tools_agent.py +++ b/backend/app/agent_types/tools_agent.py @@ -10,10 +10,10 @@ ToolMessage, ) from langgraph.checkpoint import BaseCheckpointSaver -from langgraph.graph import END from langgraph.graph.message import MessageGraph from langgraph.prebuilt import ToolExecutor, ToolInvocation +from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY from app.message_types import LiberalToolMessage @@ -89,10 +89,12 @@ async def call_tool(messages): # Define the two nodes we will cycle between workflow.add_node("agent", agent) workflow.add_node("action", call_tool) + workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION) # Set the entrypoint as `agent` # This means that this node is the first one called workflow.set_entry_point("agent") + workflow.set_finish_point(FINISH_NODE_KEY) # We now add a conditional edge workflow.add_conditional_edges( @@ -111,7 +113,7 @@ async def call_tool(messages): # If `tools`, then we call the tool node. "continue": "action", # Otherwise we finish. - "end": END, + "end": FINISH_NODE_KEY, }, ) diff --git a/backend/app/chatbot.py b/backend/app/chatbot.py index eeb5b787..59fe240a 100644 --- a/backend/app/chatbot.py +++ b/backend/app/chatbot.py @@ -5,6 +5,7 @@ from langgraph.checkpoint import BaseCheckpointSaver from langgraph.graph.state import StateGraph +from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY from app.message_types import add_messages_liberal @@ -20,7 +21,9 @@ def _get_messages(messages): workflow = StateGraph(Annotated[List[BaseMessage], add_messages_liberal]) workflow.add_node("chatbot", chatbot) + workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION) workflow.set_entry_point("chatbot") - workflow.set_finish_point("chatbot") + workflow.set_finish_point(FINISH_NODE_KEY) + workflow.add_edge("chatbot", FINISH_NODE_KEY) app = workflow.compile(checkpointer=checkpoint) return app diff --git a/backend/app/retrieval.py b/backend/app/retrieval.py index 763d2cd2..7d310152 100644 --- a/backend/app/retrieval.py +++ b/backend/app/retrieval.py @@ -8,9 +8,9 @@ from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import chain from langgraph.checkpoint import BaseCheckpointSaver -from langgraph.graph import END from langgraph.graph.state import StateGraph +from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY from app.message_types import LiberalToolMessage, add_messages_liberal search_prompt = PromptTemplate.from_template( @@ -132,9 +132,11 @@ def call_model(state: AgentState): workflow.add_node("invoke_retrieval", invoke_retrieval) workflow.add_node("retrieve", retrieve) workflow.add_node("response", call_model) + workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION) workflow.set_entry_point("invoke_retrieval") + workflow.set_finish_point(FINISH_NODE_KEY) workflow.add_edge("invoke_retrieval", "retrieve") workflow.add_edge("retrieve", "response") - workflow.add_edge("response", END) + workflow.add_edge("response", FINISH_NODE_KEY) app = workflow.compile(checkpointer=checkpoint) return app From a49fb225c40dede992607aa1bf14d91700b7b9d2 Mon Sep 17 00:00:00 2001 From: Bakar Tavadze Date: Mon, 27 May 2024 18:04:00 +0400 Subject: [PATCH 2/4] Enable setting as_node when updating state. --- backend/app/storage.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/app/storage.py b/backend/app/storage.py index 17b6aebc..fc89ec51 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -5,6 +5,7 @@ from langchain_core.runnables import RunnableConfig from app.agent import agent +from app.agent_types.constants import FINISH_NODE_KEY from app.lifespan import get_pg_pool from app.schema import Assistant, Thread, User @@ -125,6 +126,7 @@ async def update_thread_state( *, user_id: str, assistant: Assistant, + as_node: Optional[str] = FINISH_NODE_KEY, ): """Add state to a thread.""" await agent.aupdate_state( @@ -136,6 +138,7 @@ async def update_thread_state( } }, values, + as_node=as_node, ) From ef377e1428d02deb1ce7b0c4808b8830caa576c5 Mon Sep 17 00:00:00 2001 From: Bakar Tavadze Date: Mon, 27 May 2024 18:05:24 +0400 Subject: [PATCH 3/4] Separate thread POST and PUT payloads. --- backend/app/api/threads.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index e887791a..64e3065c 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Path from langchain.schema.messages import AnyMessage +from langchain_core.messages import AIMessage from pydantic import BaseModel, Field import app.storage as storage @@ -15,14 +16,21 @@ ThreadID = Annotated[str, Path(description="The ID of the thread.")] -class ThreadPutRequest(BaseModel): +class ThreadPostRequest(BaseModel): """Payload for creating a thread.""" name: str = Field(..., description="The name of the thread.") assistant_id: str = Field(..., description="The ID of the assistant to use.") -class ThreadPostRequest(BaseModel): +class ThreadPutRequest(BaseModel): + """Payload for updating a thread.""" + + name: str = Field(..., description="The name of the thread.") + assistant_id: str = Field(..., description="The ID of the assistant to use.") + + +class ThreadStatePostRequest(BaseModel): """Payload for adding state to a thread.""" values: Union[Sequence[AnyMessage], Dict[str, Any]] @@ -58,7 +66,7 @@ async def get_thread_state( async def add_thread_state( user: AuthedUser, tid: ThreadID, - payload: ThreadPostRequest, + payload: ThreadStatePostRequest, ): """Add state to a thread.""" thread = await storage.get_thread(user["user_id"], tid) @@ -109,14 +117,14 @@ async def get_thread( @router.post("") async def create_thread( user: AuthedUser, - thread_put_request: ThreadPutRequest, + payload: ThreadPostRequest, ) -> Thread: """Create a thread.""" return await storage.put_thread( user["user_id"], str(uuid4()), - assistant_id=thread_put_request.assistant_id, - name=thread_put_request.name, + assistant_id=payload.assistant_id, + name=payload.name, ) @@ -124,14 +132,14 @@ async def create_thread( async def upsert_thread( user: AuthedUser, tid: ThreadID, - thread_put_request: ThreadPutRequest, + payload: ThreadPutRequest, ) -> Thread: """Update a thread.""" return await storage.put_thread( user["user_id"], tid, - assistant_id=thread_put_request.assistant_id, - name=thread_put_request.name, + assistant_id=payload.assistant_id, + name=payload.name, ) From 0276b7d4892d42f0410228b3f5006e21977ab0a3 Mon Sep 17 00:00:00 2001 From: Bakar Tavadze Date: Mon, 27 May 2024 18:17:38 +0400 Subject: [PATCH 4/4] Enable creating a starting AI message when creating a thread. --- backend/app/api/threads.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index 64e3065c..f2b69a0d 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -21,6 +21,9 @@ class ThreadPostRequest(BaseModel): name: str = Field(..., description="The name of the thread.") assistant_id: str = Field(..., description="The ID of the assistant to use.") + starting_message: Optional[str] = Field( + None, description="The starting AI message for the thread." + ) class ThreadPutRequest(BaseModel): @@ -120,12 +123,25 @@ async def create_thread( payload: ThreadPostRequest, ) -> Thread: """Create a thread.""" - return await storage.put_thread( + assistant = await storage.get_assistant(user["user_id"], payload.assistant_id) + if not assistant: + raise HTTPException(status_code=404, detail="Assistant not found") + thread = await storage.put_thread( user["user_id"], str(uuid4()), assistant_id=payload.assistant_id, name=payload.name, ) + if payload.starting_message is not None: + message = AIMessage(id=str(uuid4()), content=payload.starting_message) + chat_retrieval = assistant["config"]["configurable"]["type"] == "chat_retrieval" + await storage.update_thread_state( + {"configurable": {"thread_id": thread["thread_id"]}}, + {"messages": [message]} if chat_retrieval else [message], + user_id=user["user_id"], + assistant=assistant, + ) + return thread @router.put("/{tid}")