Skip to content

Commit

Permalink
Merge pull request #314 from langchain-ai/nc/18apr/edit-messages
Browse files Browse the repository at this point in the history
Edit existing messages before continuing
  • Loading branch information
nfcampos authored Apr 18, 2024
2 parents 049aafd + dbc8594 commit c4590ec
Show file tree
Hide file tree
Showing 16 changed files with 924 additions and 191 deletions.
33 changes: 20 additions & 13 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from typing import Annotated, Any, Dict, List, Sequence, Union
from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
Expand Down Expand Up @@ -27,6 +26,7 @@ class ThreadPostRequest(BaseModel):
"""Payload for adding state to a thread."""

values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: Optional[Dict[str, Any]] = None


@router.get("/")
Expand All @@ -41,13 +41,14 @@ async def get_thread_state(
tid: ThreadID,
):
"""Get state for a thread."""
thread, state = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_state(user["user_id"], tid),
)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return state
return await storage.get_thread_state(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
)


@router.post("/{tid}/state")
Expand All @@ -60,7 +61,12 @@ async def add_thread_state(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return await storage.update_thread_state(user["user_id"], tid, payload.values)
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}},
payload.values,
user_id=user["user_id"],
assistant_id=thread["assistant_id"],
)


@router.get("/{tid}/history")
Expand All @@ -69,13 +75,14 @@ async def get_thread_history(
tid: ThreadID,
):
"""Get all past states for a thread."""
thread, history = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_history(user["user_id"], tid),
)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return history
return await storage.get_thread_history(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
)


@router.get("/{tid}")
Expand Down
24 changes: 13 additions & 11 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig

from app.agent import agent
from app.lifespan import get_pg_pool
Expand Down Expand Up @@ -98,10 +99,9 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
)


async def get_thread_state(user_id: str, thread_id: str):
async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str):
"""Get state for a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
state = await agent.aget_state(
{
"configurable": {
Expand All @@ -117,26 +117,28 @@ async def get_thread_state(user_id: str, thread_id: str):


async def update_thread_state(
user_id: str, thread_id: str, values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: RunnableConfig,
values: Union[Sequence[AnyMessage], dict[str, Any]],
*,
user_id: str,
assistant_id: str,
):
"""Add state to a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
await agent.aupdate_state(
{
"configurable": {
**assistant["config"]["configurable"],
"thread_id": thread_id,
**config["configurable"],
}
},
values,
)


async def get_thread_history(user_id: str, thread_id: str):
async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str):
"""Get the history of a thread."""
thread = await get_thread(user_id, thread_id)
assistant = await get_assistant(user_id, thread["assistant_id"])
assistant = await get_assistant(user_id, assistant_id)
return [
{
"values": c.values,
Expand Down
4 changes: 2 additions & 2 deletions backend/app/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def _get_wikipedia():
@lru_cache(maxsize=1)
def _get_tavily():
tavily_search = TavilySearchAPIWrapper()
return TavilySearchResults(api_wrapper=tavily_search)
return TavilySearchResults(api_wrapper=tavily_search, name="search_tavily")


@lru_cache(maxsize=1)
def _get_tavily_answer():
tavily_search = TavilySearchAPIWrapper()
return _TavilyAnswer(api_wrapper=tavily_search)
return _TavilyAnswer(api_wrapper=tavily_search, name="search_tavily_answer")


def _get_action_server(**kwargs: ActionServerConfig):
Expand Down
4 changes: 3 additions & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"name": "frontend",
"private": true,
"version": "0.0.0",
"packageManager": "yarn@1.22.19",
"type": "module",
"scripts": {
"dev": "vite --host",
Expand All @@ -11,14 +12,15 @@
"format": "prettier -w src"
},
"dependencies": {
"@codemirror/lang-json": "^6.0.1",
"@headlessui/react": "^1.7.17",
"@heroicons/react": "^2.0.18",
"@microsoft/fetch-event-source": "^2.0.1",
"@tailwindcss/forms": "^0.5.6",
"@tailwindcss/typography": "^0.5.10",
"@uiw/react-codemirror": "^4.21.25",
"clsx": "^2.0.0",
"dompurify": "^3.0.6",
"fast-json-patch": "^3.1.1",
"lodash": "^4.17.21",
"marked": "^9.1.5",
"react": "^18.2.0",
Expand Down
3 changes: 2 additions & 1 deletion frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function App(props: { edit?: boolean }) {
message: MessageWithFiles | null,
thread_id: string,
assistantType: string,
config?: Record<string, unknown>,
) => {
const files = message?.files || [];
if (files.length > 0) {
Expand Down Expand Up @@ -73,7 +74,7 @@ function App(props: { edit?: boolean }) {
}
}

await startStream(input, thread_id);
await startStream(input, thread_id, config);
},
[startStream],
);
Expand Down
148 changes: 103 additions & 45 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import { useEffect, useRef } from "react";
import { useEffect, useRef, useState } from "react";
import { StreamStateProps } from "../hooks/useStreamState";
import { useChatMessages } from "../hooks/useChatMessages";
import TypingBox from "./TypingBox";
import { MessageViewer } from "./Message";
import { ArrowDownCircleIcon } from "@heroicons/react/24/outline";
import {
ArrowDownCircleIcon,
CheckCircleIcon,
} from "@heroicons/react/24/outline";
import { MessageWithFiles } from "../utils/formTypes.ts";
import { useParams } from "react-router-dom";
import { useThreadAndAssistant } from "../hooks/useThreadAndAssistant.ts";
import { useMessageEditing } from "../hooks/useMessageEditing.ts";
import { MessageEditor } from "./MessageEditor.tsx";
import { Message } from "../types.ts";

interface ChatProps extends Pick<StreamStateProps, "stream" | "stopStream"> {
startStream: (
Expand All @@ -24,15 +30,50 @@ function usePrevious<T>(value: T): T | undefined {
return ref.current;
}

function CommitEdits(props: {
editing: Record<string, Message>;
commitEdits: () => Promise<void>;
}) {
const [inflight, setInflight] = useState(false);
return (
<div className="bg-blue-50 text-blue-800 rounded-md ring-1 ring-inset ring-blue-800/60 flex flex-row h-9 items-center">
<div className="flex-1 rounded-l-md pl-4">
{Object.keys(props.editing).length} message(s) edited.
</div>
<button
onClick={async () => {
setInflight(true);
await props.commitEdits();
setInflight(false);
}}
className={
"self-stretch -ml-px inline-flex items-center gap-x-1.5 rounded-r-md px-3 " +
"text-sm font-semibold ring-1 ring-inset ring-blue-800/60 hover:bg-blue-100 "
}
>
<CheckCircleIcon
className="w-6 h-6 cursor-pointer stroke-2 opacity-80 hover:opacity-100 transition-opacity duration-100"
onMouseUp={props.commitEdits}
/>

{inflight ? "Saving..." : "Save"}
</button>
</div>
);
}

export function Chat(props: ChatProps) {
const { chatId } = useParams();
const { messages, next } = useChatMessages(
const { messages, next, refreshMessages } = useChatMessages(
chatId ?? null,
props.stream,
props.stopStream,
);

const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant();
const { editing, recordEdits, commitEdits, abandonEdits } = useMessageEditing(
chatId,
refreshMessages,
);

const prevMessages = usePrevious(messages);
useEffect(() => {
Expand All @@ -51,17 +92,28 @@ export function Chat(props: ChatProps) {

return (
<div className="flex-1 flex flex-col items-stretch pb-[76px] pt-2">
{messages?.map((msg, i) => (
<MessageViewer
{...msg}
key={msg.id}
runId={
i === messages.length - 1 && props.stream?.status === "done"
? props.stream?.run_id
: undefined
}
/>
))}
{messages?.map((msg, i) =>
editing[msg.id] ? (
<MessageEditor
key={msg.id}
message={editing[msg.id]}
onUpdate={recordEdits}
abandonEdits={() => abandonEdits(msg)}
/>
) : (
<MessageViewer
{...msg}
key={msg.id}
runId={
i === messages.length - 1 && props.stream?.status === "done"
? props.stream?.run_id
: undefined
}
startEditing={() => recordEdits(msg)}
alwaysShowControls={i === messages.length - 1}
/>
),
)}
{(props.stream?.status === "inflight" || messages === null) && (
<div className="leading-6 mb-2 animate-pulse font-black text-gray-400 text-lg">
...
Expand All @@ -72,37 +124,43 @@ export function Chat(props: ChatProps) {
An error has occurred. Please try again.
</div>
)}
{next.length > 0 && props.stream?.status !== "inflight" && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-yellow-600/20 cursor-pointer"
onClick={() =>
props.startStream(
null,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
>
<ArrowDownCircleIcon className="h-5 w-5 mr-1" />
Click to continue.
</div>
)}
{next.length > 0 &&
props.stream?.status !== "inflight" &&
Object.keys(editing).length === 0 && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-yellow-600/20 cursor-pointer"
onClick={() =>
props.startStream(
null,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
>
<ArrowDownCircleIcon className="h-5 w-5 mr-1" />
Click to continue.
</div>
)}
<div className="fixed left-0 lg:left-72 bottom-0 right-0 p-4">
<TypingBox
onSubmit={(msg) =>
props.startStream(
msg,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
onInterrupt={
props.stream?.status === "inflight" ? props.stopStream : undefined
}
inflight={props.stream?.status === "inflight"}
currentConfig={assistantConfig}
currentChat={currentChat}
/>
{commitEdits && Object.keys(editing).length > 0 ? (
<CommitEdits editing={editing} commitEdits={commitEdits} />
) : (
<TypingBox
onSubmit={(msg) =>
props.startStream(
msg,
currentChat.thread_id,
assistantConfig.config.configurable?.type as string,
)
}
onInterrupt={
props.stream?.status === "inflight" ? props.stopStream : undefined
}
inflight={props.stream?.status === "inflight"}
currentConfig={assistantConfig}
currentChat={currentChat}
/>
)}
</div>
</div>
);
Expand Down
Loading

0 comments on commit c4590ec

Please sign in to comment.