Skip to content

Commit

Permalink
Continue streaming open threads in background when switching to anoth…
Browse files Browse the repository at this point in the history
…er thread
  • Loading branch information
nfcampos committed May 13, 2024
1 parent 40a76db commit a7eadbc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 41 deletions.
13 changes: 7 additions & 6 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function App(props: { edit?: boolean }) {
const [sidebarOpen, setSidebarOpen] = useState(false);
const { chats, createChat, updateChat, deleteChat } = useChatList();
const { configs, saveConfig, deleteConfig } = useConfigList();
const { startStream, stopStream, stream } = useStreamState();
const { startStream, stopStream, streams } = useStreamState();
const { configSchema, configDefaults } = useSchemas();

const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant();
Expand Down Expand Up @@ -92,9 +92,6 @@ function App(props: { edit?: boolean }) {

const selectChat = useCallback(
async (id: string | null) => {
if (currentChat) {
stopStream?.(true);
}
if (!id) {
const firstAssistant = configs?.[0]?.assistant_id ?? null;
navigate(firstAssistant ? `/assistant/${firstAssistant}` : "/");
Expand All @@ -106,7 +103,7 @@ function App(props: { edit?: boolean }) {
setSidebarOpen(false);
}
},
[currentChat, sidebarOpen, stopStream, configs, navigate],
[sidebarOpen, configs, navigate],
);

const selectConfig = useCallback(
Expand Down Expand Up @@ -144,7 +141,11 @@ function App(props: { edit?: boolean }) {
}
>
{currentChat && assistantConfig && (
<Chat startStream={startTurn} stopStream={stopStream} stream={stream} />
<Chat
startStream={startTurn}
stopStream={stopStream?.bind(null, currentChat.thread_id)}
stream={streams[currentChat.thread_id]}
/>
)}
{currentChat && !assistantConfig && (
<OrphanChat chat={currentChat} updateChat={updateChat} />
Expand Down
6 changes: 4 additions & 2 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { useEffect, useRef, useState } from "react";
import { StreamStateProps } from "../hooks/useStreamState";
import { StreamState } from "../hooks/useStreamState";
import { useChatMessages } from "../hooks/useChatMessages";
import TypingBox from "./TypingBox";
import { MessageViewer } from "./Message";
Expand All @@ -14,12 +14,14 @@ import { useMessageEditing } from "../hooks/useMessageEditing.ts";
import { MessageEditor } from "./MessageEditor.tsx";
import { Message } from "../types.ts";

interface ChatProps extends Pick<StreamStateProps, "stream" | "stopStream"> {
interface ChatProps {
startStream: (
message: MessageWithFiles | null,
thread_id: string,
assistantType: string,
) => Promise<void>;
stopStream?: (clear?: boolean) => void;
stream: StreamState;
}

function usePrevious<T>(value: T): T | undefined {
Expand Down
98 changes: 65 additions & 33 deletions frontend/src/hooks/useStreamState.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,19 @@ export interface StreamState {
}

export interface StreamStateProps {
stream: StreamState | null;
streams: {
[tid: string]: StreamState;
};
startStream: (
input: Message[] | Record<string, any> | null,
thread_id: string,
config?: Record<string, unknown>,
) => Promise<void>;
stopStream?: (clear?: boolean) => void;
stopStream?: (thread_id: string, clear?: boolean) => void;
}

export function useStreamState(): StreamStateProps {
const [current, setCurrent] = useState<StreamState | null>(null);
const [current, setCurrent] = useState<{ [tid: string]: StreamState }>({});
const [controller, setController] = useState<AbortController | null>(null);

const startStream = useCallback(
Expand All @@ -31,7 +33,10 @@ export function useStreamState(): StreamStateProps {
) => {
const controller = new AbortController();
setController(controller);
setCurrent({ status: "inflight", messages: input || [] });
setCurrent((threads) => ({
...threads,
[thread_id]: { status: "inflight", messages: input || [] },
}));

await fetchEventSource("/runs/stream", {
signal: controller.signal,
Expand All @@ -42,39 +47,60 @@ export function useStreamState(): StreamStateProps {
onmessage(msg) {
if (msg.event === "data") {
const messages = JSON.parse(msg.data);
setCurrent((current) => ({
status: "inflight" as StreamState["status"],
messages: mergeMessagesById(current?.messages, messages),
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "inflight" as StreamState["status"],
messages: mergeMessagesById(
threads[thread_id]?.messages,
messages,
),
run_id: threads[thread_id]?.run_id,
},
}));
} else if (msg.event === "metadata") {
const { run_id } = JSON.parse(msg.data);
setCurrent((current) => ({
status: "inflight",
messages: current?.messages,
run_id: run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "inflight" as StreamState["status"],
messages: threads[thread_id]?.messages,
run_id,
},
}));
} else if (msg.event === "error") {
setCurrent((current) => ({
status: "error",
messages: current?.messages,
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "error",
messages: threads[thread_id]?.messages,
run_id: threads[thread_id]?.run_id,
},
}));
}
},
onclose() {
setCurrent((current) => ({
status: current?.status === "error" ? current.status : "done",
messages: current?.messages,
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status:
threads[thread_id]?.status === "error"
? threads[thread_id].status
: "done",
messages: threads[thread_id]?.messages,
run_id: threads[thread_id]?.run_id,
},
}));
setController(null);
},
onerror(error) {
setCurrent((current) => ({
status: "error",
messages: current?.messages,
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "error",
messages: threads[thread_id]?.messages,
run_id: threads[thread_id]?.run_id,
},
}));
setController(null);
throw error;
Expand All @@ -85,19 +111,25 @@ export function useStreamState(): StreamStateProps {
);

const stopStream = useCallback(
(clear: boolean = false) => {
(thread_id: string, clear: boolean = false) => {
controller?.abort();
setController(null);
if (clear) {
setCurrent((current) => ({
status: "done",
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "done",
run_id: threads[thread_id]?.run_id,
},
}));
} else {
setCurrent((current) => ({
status: "done",
messages: current?.messages,
run_id: current?.run_id,
setCurrent((threads) => ({
...threads,
[thread_id]: {
status: "done",
messages: threads[thread_id]?.messages,
run_id: threads[thread_id]?.run_id,
},
}));
}
},
Expand All @@ -107,7 +139,7 @@ export function useStreamState(): StreamStateProps {
return {
startStream,
stopStream,
stream: current,
streams: current,
};
}

Expand Down

0 comments on commit a7eadbc

Please sign in to comment.