From 90cf905d3edb7401f4e084870b46c4d170c7981a Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 20 Mar 2024 16:23:27 -0700 Subject: [PATCH 1/8] Edit existing messages before continuing --- backend/app/api/threads.py | 33 +-- backend/app/storage.py | 24 +- frontend/package.json | 4 +- frontend/src/App.tsx | 3 +- frontend/src/assets/EmptyState.svg | 21 ++ frontend/src/components/Chat.tsx | 148 ++++++++---- frontend/src/components/JsonEditor.tsx | 67 ++++++ frontend/src/components/Message.tsx | 123 +++------- frontend/src/components/MessageEditor.tsx | 271 ++++++++++++++++++++++ frontend/src/components/StringEditor.tsx | 49 ++++ frontend/src/components/Tool.tsx | 93 ++++++++ frontend/src/hooks/useChatMessages.ts | 38 +-- frontend/src/hooks/useMessageEditing.ts | 49 ++++ frontend/src/hooks/useStreamState.tsx | 4 +- frontend/src/utils/equals.ts | 39 ++++ frontend/yarn.lock | 175 +++++++++++++- 16 files changed, 953 insertions(+), 188 deletions(-) create mode 100644 frontend/src/assets/EmptyState.svg create mode 100644 frontend/src/components/JsonEditor.tsx create mode 100644 frontend/src/components/MessageEditor.tsx create mode 100644 frontend/src/components/StringEditor.tsx create mode 100644 frontend/src/components/Tool.tsx create mode 100644 frontend/src/hooks/useMessageEditing.ts create mode 100644 frontend/src/utils/equals.ts diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index 741ce87b..dd6441b6 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -1,5 +1,4 @@ -import asyncio -from typing import Annotated, Any, Dict, List, Sequence, Union +from typing import Annotated, Any, Dict, List, Optional, Sequence, Union from uuid import uuid4 from fastapi import APIRouter, HTTPException, Path @@ -27,6 +26,7 @@ class ThreadPostRequest(BaseModel): """Payload for adding state to a thread.""" values: Union[Sequence[AnyMessage], Dict[str, Any]] + config: Optional[Dict[str, Any]] = None @router.get("/") @@ -41,13 +41,14 @@ async def get_thread_state( tid: ThreadID, ): """Get state for a thread.""" - thread, state = await asyncio.gather( - storage.get_thread(user["user_id"], tid), - storage.get_thread_state(user["user_id"], tid), - ) + thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - return state + return await storage.get_thread_state( + user_id=user["user_id"], + thread_id=tid, + assistant_id=thread["assistant_id"], + ) @router.post("/{tid}/state") @@ -60,7 +61,12 @@ async def add_thread_state( thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - return await storage.update_thread_state(user["user_id"], tid, payload.values) + return await storage.update_thread_state( + payload.config or {"configurable": {"thread_id": tid}}, + payload.values, + user_id=user["user_id"], + assistant_id=thread["assistant_id"], + ) @router.get("/{tid}/history") @@ -69,13 +75,14 @@ async def get_thread_history( tid: ThreadID, ): """Get all past states for a thread.""" - thread, history = await asyncio.gather( - storage.get_thread(user["user_id"], tid), - storage.get_thread_history(user["user_id"], tid), - ) + thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - return history + return await storage.get_thread_history( + user_id=user["user_id"], + thread_id=tid, + assistant_id=thread["assistant_id"], + ) @router.get("/{tid}") diff --git a/backend/app/storage.py b/backend/app/storage.py index f1878699..fdbc356a 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -1,7 +1,8 @@ from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union from langchain_core.messages import AnyMessage +from langchain_core.runnables import RunnableConfig from app.agent import agent from app.lifespan import get_pg_pool @@ -98,10 +99,9 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]: ) -async def get_thread_state(user_id: str, thread_id: str): +async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str): """Get state for a thread.""" - thread = await get_thread(user_id, thread_id) - assistant = await get_assistant(user_id, thread["assistant_id"]) + assistant = await get_assistant(user_id, assistant_id) state = await agent.aget_state( { "configurable": { @@ -117,26 +117,28 @@ async def get_thread_state(user_id: str, thread_id: str): async def update_thread_state( - user_id: str, thread_id: str, values: Union[Sequence[AnyMessage], Dict[str, Any]] + config: RunnableConfig, + values: Union[Sequence[AnyMessage], dict[str, Any]], + *, + user_id: str, + assistant_id: str, ): """Add state to a thread.""" - thread = await get_thread(user_id, thread_id) - assistant = await get_assistant(user_id, thread["assistant_id"]) + assistant = await get_assistant(user_id, assistant_id) await agent.aupdate_state( { "configurable": { **assistant["config"]["configurable"], - "thread_id": thread_id, + **config["configurable"], } }, values, ) -async def get_thread_history(user_id: str, thread_id: str): +async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str): """Get the history of a thread.""" - thread = await get_thread(user_id, thread_id) - assistant = await get_assistant(user_id, thread["assistant_id"]) + assistant = await get_assistant(user_id, assistant_id) return [ { "values": c.values, diff --git a/frontend/package.json b/frontend/package.json index f8f1c7c2..9c7ff20f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -2,6 +2,7 @@ "name": "frontend", "private": true, "version": "0.0.0", + "packageManager": "yarn@1.22.19", "type": "module", "scripts": { "dev": "vite --host", @@ -11,14 +12,15 @@ "format": "prettier -w src" }, "dependencies": { + "@codemirror/lang-json": "^6.0.1", "@headlessui/react": "^1.7.17", "@heroicons/react": "^2.0.18", "@microsoft/fetch-event-source": "^2.0.1", "@tailwindcss/forms": "^0.5.6", "@tailwindcss/typography": "^0.5.10", + "@uiw/react-codemirror": "^4.21.25", "clsx": "^2.0.0", "dompurify": "^3.0.6", - "fast-json-patch": "^3.1.1", "lodash": "^4.17.21", "marked": "^9.1.5", "react": "^18.2.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 514d3588..7190dc46 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -32,6 +32,7 @@ function App(props: { edit?: boolean }) { message: MessageWithFiles | null, thread_id: string, assistantType: string, + config?: Record, ) => { const files = message?.files || []; if (files.length > 0) { @@ -73,7 +74,7 @@ function App(props: { edit?: boolean }) { } } - await startStream(input, thread_id); + await startStream(input, thread_id, config); }, [startStream], ); diff --git a/frontend/src/assets/EmptyState.svg b/frontend/src/assets/EmptyState.svg new file mode 100644 index 00000000..ba37810b --- /dev/null +++ b/frontend/src/assets/EmptyState.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/components/Chat.tsx b/frontend/src/components/Chat.tsx index 8f5bf74b..3b52f47f 100644 --- a/frontend/src/components/Chat.tsx +++ b/frontend/src/components/Chat.tsx @@ -1,12 +1,18 @@ -import { useEffect, useRef } from "react"; +import { useEffect, useRef, useState } from "react"; import { StreamStateProps } from "../hooks/useStreamState"; import { useChatMessages } from "../hooks/useChatMessages"; import TypingBox from "./TypingBox"; import { MessageViewer } from "./Message"; -import { ArrowDownCircleIcon } from "@heroicons/react/24/outline"; +import { + ArrowDownCircleIcon, + CheckCircleIcon, +} from "@heroicons/react/24/outline"; import { MessageWithFiles } from "../utils/formTypes.ts"; import { useParams } from "react-router-dom"; import { useThreadAndAssistant } from "../hooks/useThreadAndAssistant.ts"; +import { useMessageEditing } from "../hooks/useMessageEditing.ts"; +import { MessageEditor } from "./MessageEditor.tsx"; +import { Message } from "../types.ts"; interface ChatProps extends Pick { startStream: ( @@ -24,15 +30,50 @@ function usePrevious(value: T): T | undefined { return ref.current; } +function CommitEdits(props: { + editing: Record; + commitEdits: () => Promise; +}) { + const [inflight, setInflight] = useState(false); + return ( +
+
+ {Object.keys(props.editing).length} message(s) edited. +
+ +
+ ); +} + export function Chat(props: ChatProps) { const { chatId } = useParams(); - const { messages, next } = useChatMessages( + const { messages, next, refreshMessages } = useChatMessages( chatId ?? null, props.stream, props.stopStream, ); - const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant(); + const { editing, recordEdits, commitEdits, abandonEdits } = useMessageEditing( + chatId, + refreshMessages, + ); const prevMessages = usePrevious(messages); useEffect(() => { @@ -51,17 +92,28 @@ export function Chat(props: ChatProps) { return (
- {messages?.map((msg, i) => ( - - ))} + {messages?.map((msg, i) => + editing[msg.id] ? ( + abandonEdits(msg)} + /> + ) : ( + recordEdits(msg)} + alwaysShowControls={i === messages.length - 1} + /> + ), + )} {(props.stream?.status === "inflight" || messages === null) && (
... @@ -72,37 +124,43 @@ export function Chat(props: ChatProps) { An error has occurred. Please try again.
)} - {next.length > 0 && props.stream?.status !== "inflight" && ( -
- props.startStream( - null, - currentChat.thread_id, - assistantConfig.config.configurable?.type as string, - ) - } - > - - Click to continue. -
- )} + {next.length > 0 && + props.stream?.status !== "inflight" && + Object.keys(editing).length === 0 && ( +
+ props.startStream( + null, + currentChat.thread_id, + assistantConfig.config.configurable?.type as string, + ) + } + > + + Click to continue. +
+ )}
- - props.startStream( - msg, - currentChat.thread_id, - assistantConfig.config.configurable?.type as string, - ) - } - onInterrupt={ - props.stream?.status === "inflight" ? props.stopStream : undefined - } - inflight={props.stream?.status === "inflight"} - currentConfig={assistantConfig} - currentChat={currentChat} - /> + {commitEdits && Object.keys(editing).length > 0 ? ( + + ) : ( + + props.startStream( + msg, + currentChat.thread_id, + assistantConfig.config.configurable?.type as string, + ) + } + onInterrupt={ + props.stream?.status === "inflight" ? props.stopStream : undefined + } + inflight={props.stream?.status === "inflight"} + currentConfig={assistantConfig} + currentChat={currentChat} + /> + )}
); diff --git a/frontend/src/components/JsonEditor.tsx b/frontend/src/components/JsonEditor.tsx new file mode 100644 index 00000000..526479bf --- /dev/null +++ b/frontend/src/components/JsonEditor.tsx @@ -0,0 +1,67 @@ +import CodeMirror from "@uiw/react-codemirror"; +import { json } from "@codemirror/lang-json"; +import { EditorView, keymap } from "@codemirror/view"; +import { defaultKeymap } from "@codemirror/commands"; +import { cn } from "../utils/cn"; + +export function JsonEditor(props: { + value?: string; + onChange?: (data: string) => void; + height?: string; +}) { + return ( +
+ true }, ...defaultKeymap]), + json(), + EditorView.lineWrapping, + EditorView.theme({ + "&.cm-editor": { + backgroundColor: "transparent", + transform: "translateX(-1px)", + }, + "&.cm-focused": { + outline: "none", + }, + green: { + background: "green", + }, + "& .cm-content": { + padding: "12px", + }, + "& .cm-line": { + fontFamily: "'Fira Code', monospace", + padding: 0, + overflowAnchor: "none", + fontVariantLigatures: "none", + }, + "& .cm-gutters.cm-gutters": { + backgroundColor: "transparent", + }, + "& .cm-lineNumbers .cm-gutterElement.cm-activeLineGutter": { + marginLeft: "1px", + }, + "& .cm-lineNumbers": { + minWidth: "42px", + }, + "& .cm-foldPlaceholder": { + padding: "0px 4px", + color: "hsl(var(--ls-gray-100))", + backgroundColor: "hsl(var(--divider-500))", + borderColor: "hsl(var(--divider-700))", + }, + '& .cm-gutterElement span[title="Fold line"]': { + transform: "translateY(-4px)", + display: "inline-block", + }, + }), + ]} + /> +
+ ); +} diff --git a/frontend/src/components/Message.tsx b/frontend/src/components/Message.tsx index 6ed16cc7..d091d352 100644 --- a/frontend/src/components/Message.tsx +++ b/frontend/src/components/Message.tsx @@ -1,95 +1,13 @@ import { memo, useState } from "react"; -import { MessageDocument, Message as MessageType, ToolCall } from "../types"; +import { MessageDocument, Message as MessageType } from "../types"; import { str } from "../utils/str"; import { cn } from "../utils/cn"; -import { ChevronDownIcon } from "@heroicons/react/24/outline"; +import { PencilSquareIcon } from "@heroicons/react/24/outline"; import { LangSmithActions } from "./LangSmithActions"; import { DocumentList } from "./Document"; import { omit } from "lodash"; import { StringViewer } from "./String"; - -function ToolRequest( - props: ToolCall & { - open?: boolean; - setOpen?: (open: boolean) => void; - }, -) { - return ( - <> - - Use - - {props.name && ( - - {props.name} - - )} - {props.args && ( -
-
- - - {Object.entries(props.args).map(([key, value], i) => ( - - - - - ))} - -
-
{key}
-
- {str(value)} -
-
-
- )} - - ); -} - -function ToolResponse(props: { - name?: string; - open?: boolean; - setOpen?: (open: boolean) => void; -}) { - return ( - <> - {props.name && ( - - {props.name} - - )} - {props.setOpen && ( - { - e.preventDefault(); - e.stopPropagation(); - props.setOpen?.(!props.open); - }} - > - - - )} - - ); -} +import { ToolRequest, ToolResponse } from "./Tool"; function isDocumentContent( content: MessageType["content"], @@ -132,7 +50,11 @@ export function MessageContent(props: { content: MessageType["content"] }) { : true, ); } - if (Array.isArray(content) ? content.length === 0 : !content) { + if ( + Array.isArray(content) + ? content.length === 0 + : Object.keys(content).length === 0 + ) { return null; } return
{str(content)}
; @@ -140,7 +62,11 @@ export function MessageContent(props: { content: MessageType["content"] }) { } export const MessageViewer = memo(function ( - props: MessageType & { runId?: string }, + props: MessageType & { + runId?: string; + startEditing?: () => void; + alwaysShowControls?: boolean; + }, ) { const [open, setOpen] = useState(false); const contentIsDocuments = @@ -151,15 +77,28 @@ export const MessageViewer = memo(function ( ? open : true; return ( -
+
- {props.type} +
+ {props.type} +
+ {props.startEditing && ( + + )}
{["function", "tool"].includes(props.type) && ( @@ -176,7 +115,7 @@ export const MessageViewer = memo(function (
{props.runId && ( -
+
)} diff --git a/frontend/src/components/MessageEditor.tsx b/frontend/src/components/MessageEditor.tsx new file mode 100644 index 00000000..6cffb285 --- /dev/null +++ b/frontend/src/components/MessageEditor.tsx @@ -0,0 +1,271 @@ +import { memo } from "react"; +import type { Message } from "../types"; +import { str } from "../utils/str"; +import { cn } from "../utils/cn"; +import { + XCircleIcon, + ChevronDownIcon, + PlusCircleIcon, + TrashIcon, +} from "@heroicons/react/24/outline"; +import { StringEditor } from "./StringEditor"; +import { JsonEditor } from "./JsonEditor"; + +// TODO adapt (and use) or remove +function Function(props: { + call: boolean; + name?: string; + onNameChange?: (newValue: string) => void; + argsEntries?: [string, unknown][]; + onArgsEntriesChange?: (newValue: [string, unknown][]) => void; + onRemovePressed?: () => void; + open?: boolean; + setOpen?: (open: boolean) => void; +}) { + return ( +
+
+
+ {props.call && ( + + Tool: + + )} + {props.name !== undefined && ( + props.onNameChange?.(e.target.value)} + className="rounded-md bg-gray-50 px-2 py-1 text-sm font-medium text-gray-600 ring-1 ring-inset ring-gray-500/10 -top-[1px] mr-auto focus:ring-0" + value={props.name} + /> + )} +
+ + +
+ {!props.call && props.setOpen && ( + { + e.preventDefault(); + e.stopPropagation(); + props.setOpen?.(!props.open); + }} + > + + + )} + {props.argsEntries && ( +
+ + Arguments: + +
+ + + {props.argsEntries.map(([key, value], i) => ( + + + + + ))} + + + + + + +
+ { + if (props.argsEntries !== undefined) { + props.onArgsEntriesChange?.([ + ...props.argsEntries.slice(0, i), + [e.target.value, value], + ...props.argsEntries.slice(i + 1), + ]); + } + }} + /> + +
+ { + if (props.argsEntries !== undefined) { + props.onArgsEntriesChange?.([ + ...props.argsEntries.slice(0, i), + [key, newValue], + ...props.argsEntries.slice(i + 1), + ]); + } + }} + /> + { + if (props.argsEntries !== undefined) { + props.onArgsEntriesChange?.([ + ...props.argsEntries.slice(0, i), + ...props.argsEntries.slice(i + 1), + ]); + } + }} + /> +
+
+ { + if (props.argsEntries === undefined) { + return; + } + props.onArgsEntriesChange?.([ + ...props.argsEntries, + ["", ""], + ]); + }} + /> +
+
+
+ )} +
+ ); +} + +export function ToolCallsEditor(props: { + message: Message; + onUpdate: (newValue: Message) => void; +}) { + return ( + { + try { + props.onUpdate({ + ...props.message, + tool_calls: JSON.parse(newValue), + }); + } catch (e) { + console.error(e); + } + }} + /> + ); +} + +export function MessageContentEditor(props: { + message: Message; + onUpdate: (newValue: Message) => void; +}) { + if (typeof props.message.content === "string") { + if (!props.message.content.trim()) { + return null; + } + return ( + + props.onUpdate({ + ...props.message, + content: newValue, + }) + } + className="text-gray-900 text-md leading-normal prose min-w-[65ch] bg-white" + value={props.message.content} + /> + ); + } + let content = props.message.content; + if (Array.isArray(content)) { + content = content.filter((it) => + typeof it === "object" && !!it && "type" in it + ? it.type !== "tool_use" + : true, + ); + } + if ( + Array.isArray(content) + ? content.length === 0 + : Object.keys(content).length === 0 + ) { + return null; + } + + return ( + { + try { + props.onUpdate({ + ...props.message, + content: JSON.parse(newValue), + }); + } catch (e) { + console.error(e); + } + }} + /> + ); +} + +export const MessageEditor = memo(function (props: { + message: Message; + onUpdate: (newValue: Message) => void; + abandonEdits: () => void; +}) { + const isToolRes = ["function", "tool"].includes(props.message.type); + return ( +
+
+
+
+ {props.message.type} +
+ + +
+
+ + {props.message.type === "ai" && props.message.tool_calls && ( + + )} +
+
+
+ ); +}); diff --git a/frontend/src/components/StringEditor.tsx b/frontend/src/components/StringEditor.tsx new file mode 100644 index 00000000..92f0319d --- /dev/null +++ b/frontend/src/components/StringEditor.tsx @@ -0,0 +1,49 @@ +import { cn } from "../utils/cn"; + +const COMMON_CLS = cn( + "text-sm col-[1] row-[1] m-0 resize-none overflow-hidden whitespace-pre-wrap break-words bg-transparent px-2 py-1 rounded shadow-none", +); + +export function StringEditor(props: { + value?: string | null | undefined; + placeholder?: string; + className?: string; + onChange?: (e: string) => void; + autoFocus?: boolean; + readOnly?: boolean; + cursorPointer?: boolean; + disabled?: boolean; + fullHeight?: boolean; +}) { + return ( +
+