Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit existing messages before continuing #314

Merged
merged 9 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from typing import Annotated, Any, Dict, List, Sequence, Union
from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
Expand Down Expand Up @@ -27,6 +26,7 @@ class ThreadPostRequest(BaseModel):
"""Payload for adding state to a thread."""

values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: Optional[Dict[str, Any]] = None


@router.get("/")
Expand All @@ -41,13 +41,14 @@ async def get_thread_state(
tid: ThreadID,
):
"""Get state for a thread."""
thread, state = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_state(user["user_id"], tid),
)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return state
return await storage.get_thread_state(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
)


@router.post("/{tid}/state")
Expand All @@ -60,7 +61,12 @@ async def add_thread_state(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return await storage.update_thread_state(user["user_id"], tid, payload.values)
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}},
payload.values,
user_id=user["user_id"],
assistant_id=thread["assistant_id"],
)


@router.get("/{tid}/history")
Expand All @@ -69,13 +75,14 @@ async def get_thread_history(
tid: ThreadID,
):
"""Get all past states for a thread."""
thread, history = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_history(user["user_id"], tid),
)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return history
return await storage.get_thread_history(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
)


@router.get("/{tid}")
Expand Down
24 changes: 13 additions & 11 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig

from app.agent import agent
from app.lifespan import get_pg_pool
Expand Down Expand Up @@ -98,10 +99,9 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
)


async def get_thread_state(user_id: str, thread_id: str):
async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str):
"""Get state for a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
state = await agent.aget_state(
{
"configurable": {
Expand All @@ -117,26 +117,28 @@ async def get_thread_state(user_id: str, thread_id: str):


async def update_thread_state(
user_id: str, thread_id: str, values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: RunnableConfig,
values: Union[Sequence[AnyMessage], dict[str, Any]],
*,
user_id: str,
assistant_id: str,
):
"""Add state to a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
await agent.aupdate_state(
{
"configurable": {
**assistant["config"]["configurable"],
"thread_id": thread_id,
**config["configurable"],
}
},
values,
)


async def get_thread_history(user_id: str, thread_id: str):
async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str):
"""Get the history of a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
return [
{
"values": c.values,
Expand Down
4 changes: 2 additions & 2 deletions backend/app/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def _get_wikipedia():
@lru_cache(maxsize=1)
def _get_tavily():
tavily_search = TavilySearchAPIWrapper()
return TavilySearchResults(api_wrapper=tavily_search)
return TavilySearchResults(api_wrapper=tavily_search, name="search_tavily")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't match names used in enum



@lru_cache(maxsize=1)
def _get_tavily_answer():
tavily_search = TavilySearchAPIWrapper()
return _TavilyAnswer(api_wrapper=tavily_search)
return _TavilyAnswer(api_wrapper=tavily_search, name="search_tavily_answer")


def _get_action_server(**kwargs: ActionServerConfig):
Expand Down
4 changes: 3 additions & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"name": "frontend",
"private": true,
"version": "0.0.0",
"packageManager": "yarn@1.22.19",
"type": "module",
"scripts": {
"dev": "vite --host",
Expand All @@ -11,14 +12,15 @@
"format": "prettier -w src"
},
"dependencies": {
"@codemirror/lang-json": "^6.0.1",
"@headlessui/react": "^1.7.17",
"@heroicons/react": "^2.0.18",
"@microsoft/fetch-event-source": "^2.0.1",
"@tailwindcss/forms": "^0.5.6",
"@tailwindcss/typography": "^0.5.10",
"@uiw/react-codemirror": "^4.21.25",
"clsx": "^2.0.0",
"dompurify": "^3.0.6",
"fast-json-patch": "^3.1.1",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasn't used

"lodash": "^4.17.21",
"marked": "^9.1.5",
"react": "^18.2.0",
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function App(props: { edit?: boolean }) {
message: MessageWithFiles | null,
thread_id: string,
assistantType: string,
config?: Record<string, unknown>,
) => {
const files = message?.files || [];
if (files.length > 0) {
Expand Down Expand Up @@ -73,7 +74,7 @@ function App(props: { edit?: boolean }) {
}
}

await startStream(input, thread_id);
await startStream(input, thread_id, config);
},
[startStream],
);
Expand Down
148 changes: 103 additions & 45 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import { useEffect, useRef } from "react";
import { useEffect, useRef, useState } from "react";
import { StreamStateProps } from "../hooks/useStreamState";
import { useChatMessages } from "../hooks/useChatMessages";
import TypingBox from "./TypingBox";
import { MessageViewer } from "./Message";
import { ArrowDownCircleIcon } from "@heroicons/react/24/outline";
import {
ArrowDownCircleIcon,
CheckCircleIcon,
} from "@heroicons/react/24/outline";
import { MessageWithFiles } from "../utils/formTypes.ts";
import { useParams } from "react-router-dom";
import { useThreadAndAssistant } from "../hooks/useThreadAndAssistant.ts";
import { useMessageEditing } from "../hooks/useMessageEditing.ts";
import { MessageEditor } from "./MessageEditor.tsx";
import { Message } from "../types.ts";

interface ChatProps extends Pick<StreamStateProps, "stream" | "stopStream"> {
startStream: (
Expand All @@ -24,15 +30,50 @@ function usePrevious<T>(value: T): T | undefined {
return ref.current;
}

function CommitEdits(props: {
editing: Record<string, Message>;
commitEdits: () => Promise<void>;
}) {
const [inflight, setInflight] = useState(false);
return (
<div className="bg-blue-50 text-blue-800 rounded-md ring-1 ring-inset ring-blue-800/60 flex flex-row h-9 items-center">
<div className="flex-1 rounded-l-md pl-4">
{Object.keys(props.editing).length} message(s) edited.
</div>
<button
onClick={async () => {
setInflight(true);
await props.commitEdits();
setInflight(false);
}}
className={
"self-stretch -ml-px inline-flex items-center gap-x-1.5 rounded-r-md px-3 " +
"text-sm font-semibold ring-1 ring-inset ring-blue-800/60 hover:bg-blue-100 "
}
>
<CheckCircleIcon
className="w-6 h-6 cursor-pointer stroke-2 opacity-80 hover:opacity-100 transition-opacity duration-100"
onMouseUp={props.commitEdits}
/>

{inflight ? "Saving..." : "Save"}
</button>
</div>
);
}

export function Chat(props: ChatProps) {
const { chatId } = useParams();
const { messages, next } = useChatMessages(
const { messages, next, refreshMessages } = useChatMessages(
chatId ?? null,
props.stream,
props.stopStream,
);

const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant();
const { editing, recordEdits, commitEdits, abandonEdits } = useMessageEditing(
chatId,
refreshMessages,
);

const prevMessages = usePrevious(messages);
useEffect(() => {
Expand All @@ -51,17 +92,28 @@ export function Chat(props: ChatProps) {

return (
<div className="flex-1 flex flex-col items-stretch pb-[76px] pt-2">
{messages?.map((msg, i) => (
<MessageViewer
{...msg}
key={msg.id}
runId={
i === messages.length - 1 && props.stream?.status === "done"
? props.stream?.run_id
: undefined
}
/>
))}
{messages?.map((msg, i) =>
editing[msg.id] ? (
<MessageEditor
key={msg.id}
message={editing[msg.id]}
onUpdate={recordEdits}
abandonEdits={() => abandonEdits(msg)}
/>
) : (
<MessageViewer
{...msg}
key={msg.id}
runId={
i === messages.length - 1 && props.stream?.status === "done"
? props.stream?.run_id
: undefined
}
startEditing={() => recordEdits(msg)}
alwaysShowControls={i === messages.length - 1}
/>
),
)}
{(props.stream?.status === "inflight" || messages === null) && (
<div className="leading-6 mb-2 animate-pulse font-black text-gray-400 text-lg">
...
Expand All @@ -72,37 +124,43 @@ export function Chat(props: ChatProps) {
An error has occurred. Please try again.
</div>
)}
{next.length > 0 && props.stream?.status !== "inflight" && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-yellow-600/20 cursor-pointer"
onClick={() =>
props.startStream(
null,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
>
<ArrowDownCircleIcon className="h-5 w-5 mr-1" />
Click to continue.
</div>
)}
{next.length > 0 &&
props.stream?.status !== "inflight" &&
Object.keys(editing).length === 0 && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-yellow-600/20 cursor-pointer"
onClick={() =>
props.startStream(
null,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
>
<ArrowDownCircleIcon className="h-5 w-5 mr-1" />
Click to continue.
</div>
)}
<div className="fixed left-0 lg:left-72 bottom-0 right-0 p-4">
<TypingBox
onSubmit={(msg) =>
props.startStream(
msg,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
onInterrupt={
props.stream?.status === "inflight" ? props.stopStream : undefined
}
inflight={props.stream?.status === "inflight"}
currentConfig={assistantConfig}
currentChat={currentChat}
/>
{commitEdits && Object.keys(editing).length > 0 ? (
<CommitEdits editing={editing} commitEdits={commitEdits} />
) : (
<TypingBox
onSubmit={(msg) =>
props.startStream(
msg,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
onInterrupt={
props.stream?.status === "inflight" ? props.stopStream : undefined
}
inflight={props.stream?.status === "inflight"}
currentConfig={assistantConfig}
currentChat={currentChat}
/>
)}
</div>
</div>
);
Expand Down
Loading
Loading