diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index 1dd8d2d3..31c08c95 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence, Union import langsmith.client from fastapi import APIRouter, BackgroundTasks, HTTPException @@ -24,8 +24,8 @@ class CreateRunPayload(BaseModel): """Payload for creating a run.""" thread_id: str - input: Optional[Sequence[AnyMessage]] = Field(default_factory=list) config: Optional[RunnableConfig] = None + input: Optional[Union[Sequence[AnyMessage], Dict]] = Field(default_factory=list) async def _run_input_and_config( diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index 31fe6584..5087f895 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -1,4 +1,4 @@ -from typing import Annotated, List, Sequence +from typing import Annotated, Any, Dict, List, Optional, Sequence, Union from uuid import uuid4 from fastapi import APIRouter, HTTPException, Path @@ -21,10 +21,11 @@ class ThreadPutRequest(BaseModel): assistant_id: str = Field(..., description="The ID of the assistant to use.") -class ThreadMessagesPostRequest(BaseModel): +class ThreadPostRequest(BaseModel): """Payload for adding messages to a thread.""" - messages: Sequence[AnyMessage] + values: Optional[Union[Dict[str, Any], Sequence[AnyMessage]]] + config: Optional[Dict[str, Any]] = None @router.get("/") @@ -33,23 +34,25 @@ async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]: return await storage.list_threads(opengpts_user_id) -@router.get("/{tid}/messages") -async def get_thread_messages( +@router.get("/{tid}/state") +async def get_thread_state( opengpts_user_id: OpengptsUserId, tid: ThreadID, ): """Get all messages for a thread.""" - return await storage.get_thread_messages(opengpts_user_id, tid) + return await storage.get_thread_state(opengpts_user_id, tid) -@router.post("/{tid}/messages") -async def add_thread_messages( +@router.post("/{tid}/state") +async def update_thread_state( + payload: ThreadPostRequest, opengpts_user_id: OpengptsUserId, tid: ThreadID, - payload: ThreadMessagesPostRequest, ): """Add messages to a thread.""" - return await storage.post_thread_messages(opengpts_user_id, tid, payload.messages) + return await storage.update_thread_state( + payload.config or {"configurable": {"thread_id": tid}}, payload.values + ) @router.get("/{tid}/history") diff --git a/backend/app/storage.py b/backend/app/storage.py index a4b0753f..d8c401b1 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -1,7 +1,8 @@ from datetime import datetime, timezone -from typing import List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Union from langchain_core.messages import AnyMessage +from langchain_core.runnables import RunnableConfig from app.agent import AgentType, get_agent_executor from app.lifespan import get_pg_pool @@ -98,37 +99,36 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]: ) -async def get_thread_messages(user_id: str, thread_id: str): +async def get_thread_state(user_id: str, thread_id: str): """Get all messages for a thread.""" app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False) state = await app.aget_state({"configurable": {"thread_id": thread_id}}) return { - "messages": state.values, - "resumeable": bool(state.next), + "values": state.values, + "next": state.next, } -async def post_thread_messages( - user_id: str, thread_id: str, messages: Sequence[AnyMessage] +async def update_thread_state( + config: RunnableConfig, messages: Union[Sequence[AnyMessage], dict[str, Any]] ): """Add messages to a thread.""" app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False) - await app.aupdate_state({"configurable": {"thread_id": thread_id}}, messages) + return await app.aupdate_state(config, messages) async def get_thread_history(user_id: str, thread_id: str): """Get the history of a thread.""" app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False) + config = {"configurable": {"thread_id": thread_id}} return [ { "values": c.values, - "resumeable": bool(c.next), + "next": c.next, "config": c.config, "parent": c.parent_config, } - async for c in app.aget_state_history( - {"configurable": {"thread_id": thread_id}} - ) + async for c in app.aget_state_history(config) ] diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index f2bfdc6c..d23ccb67 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -110,9 +110,9 @@ async def test_threads() -> None: ) assert response.status_code == 200, response.text - response = await client.get(f"/threads/{tid}/messages", headers=headers) + response = await client.get(f"/threads/{tid}/state", headers=headers) assert response.status_code == 200 - assert response.json() == {"messages": [], "resumeable": False} + assert response.json() == {"values": [], "resumeable": False} response = await client.get("/threads/", headers=headers) diff --git a/frontend/package.json b/frontend/package.json index f8f1c7c2..e639aa1e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -2,6 +2,7 @@ "name": "frontend", "private": true, "version": "0.0.0", + "packageManager": "yarn@1.22.19", "type": "module", "scripts": { "dev": "vite --host", @@ -11,9 +12,12 @@ "format": "prettier -w src" }, "dependencies": { + "@emotion/react": "^11.11.4", + "@emotion/styled": "^11.11.0", "@headlessui/react": "^1.7.17", "@heroicons/react": "^2.0.18", "@microsoft/fetch-event-source": "^2.0.1", + "@mui/material": "^5.15.14", "@tailwindcss/forms": "^0.5.6", "@tailwindcss/typography": "^0.5.10", "clsx": "^2.0.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 001cae28..46993dc5 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -27,7 +27,11 @@ function App(props: { edit?: boolean }) { const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant(); const startTurn = useCallback( - async (message: MessageWithFiles | null, thread_id: string) => { + async ( + message: MessageWithFiles | null, + thread_id: string, + config?: Record, + ) => { const files = message?.files || []; if (files.length > 0) { const formData = files.reduce((formData, file) => { @@ -56,6 +60,7 @@ function App(props: { edit?: boolean }) { ] : null, thread_id, + config, ); }, [startStream], diff --git a/frontend/src/assets/EmptyState.svg b/frontend/src/assets/EmptyState.svg new file mode 100644 index 00000000..ba37810b --- /dev/null +++ b/frontend/src/assets/EmptyState.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/components/AutosizeTextarea.tsx b/frontend/src/components/AutosizeTextarea.tsx new file mode 100644 index 00000000..ee0b0b89 --- /dev/null +++ b/frontend/src/components/AutosizeTextarea.tsx @@ -0,0 +1,60 @@ +import { Ref } from "react"; +import { cn } from "../utils/cn"; + +const COMMON_CLS = cn( + "text-sm col-[1] row-[1] m-0 resize-none overflow-hidden whitespace-pre-wrap break-words bg-transparent px-2 py-1 rounded shadow-none", +); + +export function AutosizeTextarea(props: { + id?: string; + inputRef?: Ref; + value?: string | null | undefined; + placeholder?: string; + className?: string; + onChange?: (e: string) => void; + onFocus?: () => void; + onBlur?: () => void; + onKeyDown?: (e: React.KeyboardEvent) => void; + autoFocus?: boolean; + readOnly?: boolean; + cursorPointer?: boolean; + disabled?: boolean; + fullHeight?: boolean; +}) { + return ( +
+