Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable creating a starting AI message when creating a thread #348

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/app/agent_types/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FINISH_NODE_KEY = "finish"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this to all cognitive architectures to have a dedicated node that signifies the end of graph workflow. I know that END exists but doing as_node=END lead to an error that stated that END wasn't part of the nodes.

FINISH_NODE_ACTION = lambda _: None # noqa: E731
6 changes: 4 additions & 2 deletions backend/app/agent_types/tools_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
ToolMessage,
)
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY
from app.message_types import LiberalToolMessage


Expand Down Expand Up @@ -89,10 +89,12 @@ async def call_tool(messages):
# Define the two nodes we will cycle between
workflow.add_node("agent", agent)
workflow.add_node("action", call_tool)
workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
workflow.set_finish_point(FINISH_NODE_KEY)

# We now add a conditional edge
workflow.add_conditional_edges(
Expand All @@ -111,7 +113,7 @@ async def call_tool(messages):
# If `tools`, then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
"end": FINISH_NODE_KEY,
},
)

Expand Down
44 changes: 34 additions & 10 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fastapi import APIRouter, HTTPException, Path
from langchain.schema.messages import AnyMessage
from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field

import app.storage as storage
Expand All @@ -15,14 +16,24 @@
ThreadID = Annotated[str, Path(description="The ID of the thread.")]


class ThreadPutRequest(BaseModel):
class ThreadPostRequest(BaseModel):
"""Payload for creating a thread."""

name: str = Field(..., description="The name of the thread.")
assistant_id: str = Field(..., description="The ID of the assistant to use.")
starting_message: Optional[str] = Field(
None, description="The starting AI message for the thread."
)


class ThreadPostRequest(BaseModel):
class ThreadPutRequest(BaseModel):
"""Payload for updating a thread."""

name: str = Field(..., description="The name of the thread.")
assistant_id: str = Field(..., description="The ID of the assistant to use.")


class ThreadStatePostRequest(BaseModel):
"""Payload for adding state to a thread."""

values: Union[Sequence[AnyMessage], Dict[str, Any]]
Expand Down Expand Up @@ -58,7 +69,7 @@ async def get_thread_state(
async def add_thread_state(
user: AuthedUser,
tid: ThreadID,
payload: ThreadPostRequest,
payload: ThreadStatePostRequest,
):
"""Add state to a thread."""
thread = await storage.get_thread(user["user_id"], tid)
Expand Down Expand Up @@ -109,29 +120,42 @@ async def get_thread(
@router.post("")
async def create_thread(
user: AuthedUser,
thread_put_request: ThreadPutRequest,
payload: ThreadPostRequest,
) -> Thread:
"""Create a thread."""
return await storage.put_thread(
assistant = await storage.get_assistant(user["user_id"], payload.assistant_id)
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
thread = await storage.put_thread(
user["user_id"],
str(uuid4()),
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
assistant_id=payload.assistant_id,
name=payload.name,
)
if payload.starting_message is not None:
message = AIMessage(id=str(uuid4()), content=payload.starting_message)
chat_retrieval = assistant["config"]["configurable"]["type"] == "chat_retrieval"
await storage.update_thread_state(
{"configurable": {"thread_id": thread["thread_id"]}},
{"messages": [message]} if chat_retrieval else [message],
user_id=user["user_id"],
assistant=assistant,
)
return thread


@router.put("/{tid}")
async def upsert_thread(
user: AuthedUser,
tid: ThreadID,
thread_put_request: ThreadPutRequest,
payload: ThreadPutRequest,
) -> Thread:
"""Update a thread."""
return await storage.put_thread(
user["user_id"],
tid,
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
assistant_id=payload.assistant_id,
name=payload.name,
)


Expand Down
5 changes: 4 additions & 1 deletion backend/app/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph.state import StateGraph

from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY
from app.message_types import add_messages_liberal


Expand All @@ -20,7 +21,9 @@ def _get_messages(messages):

workflow = StateGraph(Annotated[List[BaseMessage], add_messages_liberal])
workflow.add_node("chatbot", chatbot)
workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION)
workflow.set_entry_point("chatbot")
workflow.set_finish_point("chatbot")
workflow.set_finish_point(FINISH_NODE_KEY)
workflow.add_edge("chatbot", FINISH_NODE_KEY)
app = workflow.compile(checkpointer=checkpoint)
return app
6 changes: 4 additions & 2 deletions backend/app/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import chain
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.state import StateGraph

from app.agent_types.constants import FINISH_NODE_ACTION, FINISH_NODE_KEY
from app.message_types import LiberalToolMessage, add_messages_liberal

search_prompt = PromptTemplate.from_template(
Expand Down Expand Up @@ -132,9 +132,11 @@ def call_model(state: AgentState):
workflow.add_node("invoke_retrieval", invoke_retrieval)
workflow.add_node("retrieve", retrieve)
workflow.add_node("response", call_model)
workflow.add_node(FINISH_NODE_KEY, FINISH_NODE_ACTION)
workflow.set_entry_point("invoke_retrieval")
workflow.set_finish_point(FINISH_NODE_KEY)
workflow.add_edge("invoke_retrieval", "retrieve")
workflow.add_edge("retrieve", "response")
workflow.add_edge("response", END)
workflow.add_edge("response", FINISH_NODE_KEY)
app = workflow.compile(checkpointer=checkpoint)
return app
3 changes: 3 additions & 0 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langchain_core.runnables import RunnableConfig

from app.agent import agent
from app.agent_types.constants import FINISH_NODE_KEY
from app.lifespan import get_pg_pool
from app.schema import Assistant, Thread, User

Expand Down Expand Up @@ -125,6 +126,7 @@ async def update_thread_state(
*,
user_id: str,
assistant: Assistant,
as_node: Optional[str] = FINISH_NODE_KEY,
):
"""Add state to a thread."""
await agent.aupdate_state(
Expand All @@ -136,6 +138,7 @@ async def update_thread_state(
}
},
values,
as_node=as_node,
)


Expand Down
Loading