diff --git a/backend/app/api/assistants.py b/backend/app/api/assistants.py index 8458b48c..b6f5ff0c 100644 --- a/backend/app/api/assistants.py +++ b/backend/app/api/assistants.py @@ -75,3 +75,13 @@ async def upsert_assistant( config=payload.config, public=payload.public, ) + + +@router.delete("/{aid}") +async def delete_assistant( + user: AuthedUser, + aid: AssistantID, +): + """Delete an assistant by ID.""" + await storage.delete_assistant(user["user_id"], aid) + return {"status": "ok"} diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index dd6441b6..e887791a 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -44,10 +44,13 @@ async def get_thread_state( thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") + assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"]) + if not assistant: + raise HTTPException(status_code=400, detail="Thread has no assistant") return await storage.get_thread_state( user_id=user["user_id"], thread_id=tid, - assistant_id=thread["assistant_id"], + assistant=assistant, ) @@ -61,11 +64,14 @@ async def add_thread_state( thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") + assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"]) + if not assistant: + raise HTTPException(status_code=400, detail="Thread has no assistant") return await storage.update_thread_state( payload.config or {"configurable": {"thread_id": tid}}, payload.values, user_id=user["user_id"], - assistant_id=thread["assistant_id"], + assistant=assistant, ) @@ -78,10 +84,13 @@ async def get_thread_history( thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") + assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"]) + if not assistant: + raise HTTPException(status_code=400, detail="Thread has no assistant") return await storage.get_thread_history( user_id=user["user_id"], thread_id=tid, - assistant_id=thread["assistant_id"], + assistant=assistant, ) diff --git a/backend/app/lifespan.py b/backend/app/lifespan.py index 8e15f139..0b8bb005 100644 --- a/backend/app/lifespan.py +++ b/backend/app/lifespan.py @@ -19,6 +19,12 @@ async def _init_connection(conn) -> None: decoder=orjson.loads, schema="pg_catalog", ) + await conn.set_type_codec( + "jsonb", + encoder=lambda v: orjson.dumps(v).decode(), + decoder=orjson.loads, + schema="pg_catalog", + ) await conn.set_type_codec( "uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog" ) diff --git a/backend/app/schema.py b/backend/app/schema.py index 4b5153a3..3ae6e595 100644 --- a/backend/app/schema.py +++ b/backend/app/schema.py @@ -41,3 +41,4 @@ class Thread(TypedDict): """The name of the thread.""" updated_at: datetime """The last time the thread was updated.""" + metadata: Optional[dict] diff --git a/backend/app/storage.py b/backend/app/storage.py index edfbc585..17b6aebc 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -76,6 +76,16 @@ async def put_assistant( } +async def delete_assistant(user_id: str, assistant_id: str) -> None: + """Delete an assistant by ID.""" + async with get_pg_pool().acquire() as conn: + await conn.execute( + "DELETE FROM assistant WHERE assistant_id = $1 AND user_id = $2", + assistant_id, + user_id, + ) + + async def list_threads(user_id: str) -> List[Thread]: """List all threads for the current user.""" async with get_pg_pool().acquire() as conn: @@ -92,15 +102,14 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]: ) -async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str): +async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant): """Get state for a thread.""" - assistant = await get_assistant(user_id, assistant_id) state = await agent.aget_state( { "configurable": { **assistant["config"]["configurable"], "thread_id": thread_id, - "assistant_id": assistant_id, + "assistant_id": assistant["assistant_id"], } } ) @@ -115,25 +124,23 @@ async def update_thread_state( values: Union[Sequence[AnyMessage], dict[str, Any]], *, user_id: str, - assistant_id: str, + assistant: Assistant, ): """Add state to a thread.""" - assistant = await get_assistant(user_id, assistant_id) await agent.aupdate_state( { "configurable": { **assistant["config"]["configurable"], **config["configurable"], - "assistant_id": assistant_id, + "assistant_id": assistant["assistant_id"], } }, values, ) -async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str): +async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assistant): """Get the history of a thread.""" - assistant = await get_assistant(user_id, assistant_id) return [ { "values": c.values, @@ -146,7 +153,7 @@ async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str) "configurable": { **assistant["config"]["configurable"], "thread_id": thread_id, - "assistant_id": assistant_id, + "assistant_id": assistant["assistant_id"], } } ) @@ -158,21 +165,29 @@ async def put_thread( ) -> Thread: """Modify a thread.""" updated_at = datetime.now(timezone.utc) + assistant = await get_assistant(user_id, assistant_id) + metadata = ( + {"assistant_type": assistant["config"]["configurable"]["type"]} + if assistant + else None + ) async with get_pg_pool().acquire() as conn: await conn.execute( ( - "INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at) VALUES ($1, $2, $3, $4, $5) " + "INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at, metadata) VALUES ($1, $2, $3, $4, $5, $6) " "ON CONFLICT (thread_id) DO UPDATE SET " "user_id = EXCLUDED.user_id," "assistant_id = EXCLUDED.assistant_id, " "name = EXCLUDED.name, " - "updated_at = EXCLUDED.updated_at;" + "updated_at = EXCLUDED.updated_at, " + "metadata = EXCLUDED.metadata;" ), thread_id, user_id, assistant_id, name, updated_at, + metadata, ) return { "thread_id": thread_id, @@ -180,9 +195,20 @@ async def put_thread( "assistant_id": assistant_id, "name": name, "updated_at": updated_at, + "metadata": metadata, } +async def delete_thread(user_id: str, thread_id: str): + """Delete a thread by ID.""" + async with get_pg_pool().acquire() as conn: + await conn.execute( + "DELETE FROM thread WHERE thread_id = $1 AND user_id = $2", + thread_id, + user_id, + ) + + async def get_or_create_user(sub: str) -> tuple[User, bool]: """Returns a tuple of the user and a boolean indicating whether the user was created.""" async with get_pg_pool().acquire() as conn: @@ -192,13 +218,3 @@ async def get_or_create_user(sub: str) -> tuple[User, bool]: 'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub ) return user, True - - -async def delete_thread(user_id: str, thread_id: str): - """Delete a thread by ID.""" - async with get_pg_pool().acquire() as conn: - await conn.execute( - "DELETE FROM thread WHERE thread_id = $1 AND user_id = $2", - thread_id, - user_id, - ) diff --git a/backend/migrations/000004_add_metadata_to_thread.down.sql b/backend/migrations/000004_add_metadata_to_thread.down.sql new file mode 100644 index 00000000..106fd0ba --- /dev/null +++ b/backend/migrations/000004_add_metadata_to_thread.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE thread +DROP COLUMN metadata; \ No newline at end of file diff --git a/backend/migrations/000004_add_metadata_to_thread.up.sql b/backend/migrations/000004_add_metadata_to_thread.up.sql new file mode 100644 index 00000000..d0394582 --- /dev/null +++ b/backend/migrations/000004_add_metadata_to_thread.up.sql @@ -0,0 +1,9 @@ +ALTER TABLE thread +ADD COLUMN metadata JSONB; + +UPDATE thread +SET metadata = json_build_object( + 'assistant_type', (SELECT config->'configurable'->>'type' + FROM assistant + WHERE assistant.assistant_id = thread.assistant_id) +); \ No newline at end of file diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index 7a5d0a21..b84f887f 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -112,7 +112,14 @@ async def test_threads() -> None: assert response.status_code == 200 assert [ _project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json() - ] == [{"assistant_id": aid, "name": "bobby", "thread_id": tid}] + ] == [ + { + "assistant_id": aid, + "name": "bobby", + "thread_id": tid, + "metadata": {"assistant_type": "chatbot"}, + } + ] response = await client.put( f"/threads/{tid}", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 13949af8..0acc62f3 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -16,12 +16,13 @@ import { MessageWithFiles } from "./utils/formTypes.ts"; import { useNavigate } from "react-router-dom"; import { useThreadAndAssistant } from "./hooks/useThreadAndAssistant.ts"; import { Message } from "./types.ts"; +import { OrphanChat } from "./components/OrphanChat.tsx"; function App(props: { edit?: boolean }) { const navigate = useNavigate(); const [sidebarOpen, setSidebarOpen] = useState(false); - const { chats, createChat, deleteChat } = useChatList(); - const { configs, saveConfig } = useConfigList(); + const { chats, createChat, updateChat, deleteChat } = useChatList(); + const { configs, saveConfig, deleteConfig } = useConfigList(); const { startStream, stopStream, stream } = useStreamState(); const { configSchema, configDefaults } = useSchemas(); @@ -145,6 +146,9 @@ function App(props: { edit?: boolean }) { {currentChat && assistantConfig && ( )} + {currentChat && !assistantConfig && ( + + )} {!currentChat && assistantConfig && !props.edit && ( )} {!currentChat && assistantConfig && props.edit && ( diff --git a/frontend/src/api/assistants.ts b/frontend/src/api/assistants.ts index 74402629..aa5a175a 100644 --- a/frontend/src/api/assistants.ts +++ b/frontend/src/api/assistants.ts @@ -14,3 +14,16 @@ export async function getAssistant( return null; } } + +export async function getAssistants(): Promise { + try { + const response = await fetch(`/assistants/`); + if (!response.ok) { + return null; + } + return (await response.json()) as Config[]; + } catch (error) { + console.error("Failed to fetch assistants:", error); + return null; + } +} diff --git a/frontend/src/components/ChatList.tsx b/frontend/src/components/ChatList.tsx index 3e79e94d..470e1668 100644 --- a/frontend/src/components/ChatList.tsx +++ b/frontend/src/components/ChatList.tsx @@ -139,7 +139,13 @@ export function ChatList(props: { role="menuitem" onClick={(event) => { event.preventDefault(); - props.deleteChat(chat.thread_id); + if ( + window.confirm( + `Are you sure you want to delete chat "${chat.name}"?`, + ) + ) { + props.deleteChat(chat.thread_id); + } }} > Delete diff --git a/frontend/src/components/ConfigList.tsx b/frontend/src/components/ConfigList.tsx index f93eeb7c..8f3f607a 100644 --- a/frontend/src/components/ConfigList.tsx +++ b/frontend/src/components/ConfigList.tsx @@ -1,13 +1,14 @@ import { TYPES } from "../constants"; import { Config, ConfigListProps } from "../hooks/useConfigList"; import { cn } from "../utils/cn"; -import { PencilSquareIcon } from "@heroicons/react/24/outline"; +import { PencilSquareIcon, TrashIcon } from "@heroicons/react/24/outline"; import { Link } from "react-router-dom"; function ConfigItem(props: { config: Config; currentConfig: Config | null; enterConfig: (id: string | null) => void; + deleteConfig: (id: string) => void; }) { return (
  • @@ -50,6 +51,22 @@ function ConfigItem(props: { > + { + event.preventDefault(); + if ( + window.confirm( + `Are you sure you want to delete bot "${props.config.name}?"`, + ) + ) { + props.deleteConfig(props.config.assistant_id); + } + }} + > + +
  • ); @@ -59,6 +76,7 @@ export function ConfigList(props: { configs: ConfigListProps["configs"]; currentConfig: Config | null; enterConfig: (id: string | null) => void; + deleteConfig: (id: string) => void; }) { return ( <> @@ -74,6 +92,7 @@ export function ConfigList(props: { config={assistant} currentConfig={props.currentConfig} enterConfig={props.enterConfig} + deleteConfig={props.deleteConfig} /> )) ?? (
  • @@ -94,6 +113,7 @@ export function ConfigList(props: { config={assistant} currentConfig={props.currentConfig} enterConfig={props.enterConfig} + deleteConfig={props.deleteConfig} /> )) ?? (
  • diff --git a/frontend/src/components/NewChat.tsx b/frontend/src/components/NewChat.tsx index 4d22e19a..58c284cc 100644 --- a/frontend/src/components/NewChat.tsx +++ b/frontend/src/components/NewChat.tsx @@ -15,6 +15,7 @@ interface NewChatProps extends ConfigListProps { configSchema: Schemas["configSchema"]; configDefaults: Schemas["configDefaults"]; enterConfig: (id: string | null) => void; + deleteConfig: (id: string) => Promise; startChat: ( config: ConfigInterface, message: MessageWithFiles, @@ -39,11 +40,12 @@ export function NewChat(props: NewChatProps) { )} >
    -
    +
    navigator(`/assistant/${id}`)} + deleteConfig={props.deleteConfig} />
    diff --git a/frontend/src/components/OrphanChat.tsx b/frontend/src/components/OrphanChat.tsx new file mode 100644 index 00000000..1369e90e --- /dev/null +++ b/frontend/src/components/OrphanChat.tsx @@ -0,0 +1,110 @@ +import { useEffect, useState } from "react"; +import { Config } from "../hooks/useConfigList"; +import { Chat } from "../types"; +import { getAssistants } from "../api/assistants"; +import { useThreadAndAssistant } from "../hooks/useThreadAndAssistant"; + +export function OrphanChat(props: { + chat: Chat; + updateChat: ( + name: string, + thread_id: string, + assistant_id: string | null, + ) => Promise; +}) { + const [newConfigId, setNewConfigId] = useState(null as string | null); + const [configs, setConfigs] = useState([]); + const { invalidateChat } = useThreadAndAssistant(); + + const update = async () => { + if (!newConfigId) { + alert("Please select a bot."); + return; + } + const updatedChat = await props.updateChat( + props.chat.thread_id, + props.chat.name, + newConfigId, + ); + invalidateChat(updatedChat.thread_id); + }; + + const botTypeToName = (botType: string) => { + switch (botType) { + case "chatbot": + return "Chatbot"; + case "chat_retrieval": + return "RAG"; + case "agent": + return "Assistant"; + default: + return botType; + } + }; + + useEffect(() => { + async function fetchConfigs() { + const configs = await getAssistants(); + const suitableConfigs = configs + ? configs.filter( + (config) => + config.config.configurable?.type === + props.chat.metadata?.assistant_type, + ) + : []; + setConfigs(suitableConfigs); + } + + fetchConfigs(); + }, [props.chat.metadata?.assistant_type]); + + return ( +
    + {configs.length ? ( +
    { + e.preventDefault(); + await update(); + }} + className="space-y-4 max-w-xl w-full px-4" + > +
    + This chat has no bot attached. To continue chatting, please attach a + bot. +
    +
    +
    + +
    + +
    +
    + ) : ( +
    +
    + This chat has no bot attached. To continue chatting, you need to + attach a bot. However, there are no suitable bots available for this + chat. Please create a new bot with type{" "} + {botTypeToName(props.chat.metadata?.assistant_type as string)} and + try again. +
    +
    + )} +
    + ); +} diff --git a/frontend/src/hooks/useChatList.ts b/frontend/src/hooks/useChatList.ts index b5a51bb6..5778b5e4 100644 --- a/frontend/src/hooks/useChatList.ts +++ b/frontend/src/hooks/useChatList.ts @@ -4,10 +4,11 @@ import { Chat } from "../types"; export interface ChatListProps { chats: Chat[] | null; - createChat: ( + createChat: (name: string, assistant_id: string) => Promise; + updateChat: ( name: string, - assistant_id: string, - thread_id?: string, + thread_id: string, + assistant_id: string | null, ) => Promise; deleteChat: (thread_id: string) => Promise; } @@ -57,6 +58,23 @@ export function useChatList(): ChatListProps { return saved; }, []); + const updateChat = useCallback( + async (thread_id: string, name: string, assistant_id: string | null) => { + const response = await fetch(`/threads/${thread_id}`, { + method: "PUT", + body: JSON.stringify({ assistant_id, name }), + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + }); + const saved = await response.json(); + setChats(saved); + return saved; + }, + [], + ); + const deleteChat = useCallback( async (thread_id: string) => { await fetch(`/threads/${thread_id}`, { @@ -73,6 +91,7 @@ export function useChatList(): ChatListProps { return { chats, createChat, + updateChat, deleteChat, }; } diff --git a/frontend/src/hooks/useConfigList.ts b/frontend/src/hooks/useConfigList.ts index bd50ddf9..91014f05 100644 --- a/frontend/src/hooks/useConfigList.ts +++ b/frontend/src/hooks/useConfigList.ts @@ -1,5 +1,6 @@ import { useCallback, useEffect, useReducer } from "react"; import orderBy from "lodash/orderBy"; +import { getAssistants } from "../api/assistants"; export interface Config { assistant_id: string; @@ -29,6 +30,7 @@ export interface ConfigListProps { isPublic: boolean, assistantId?: string, ) => Promise; + deleteConfig: (assistantId: string) => Promise; } function configsReducer( @@ -51,14 +53,10 @@ export function useConfigList(): ConfigListProps { useEffect(() => { async function fetchConfigs() { - const myConfigs = await fetch("/assistants/", { - headers: { - Accept: "application/json", - }, - }) - .then((r) => r.json()) - .then((li) => li.map((c: Config) => ({ ...c, mine: true }))); - setConfigs(myConfigs); + const assistants = await getAssistants(); + setConfigs( + assistants ? assistants.map((c) => ({ ...c, mine: true })) : [], + ); } fetchConfigs(); @@ -105,8 +103,22 @@ export function useConfigList(): ConfigListProps { [], ); + const deleteConfig = useCallback( + async (assistantId: string): Promise => { + await fetch(`/assistants/${assistantId}`, { + method: "DELETE", + headers: { + Accept: "application/json", + }, + }); + setConfigs((configs || []).filter((c) => c.assistant_id !== assistantId)); + }, + [configs], + ); + return { configs, saveConfig, + deleteConfig, }; } diff --git a/frontend/src/hooks/useThreadAndAssistant.ts b/frontend/src/hooks/useThreadAndAssistant.ts index 4b48fd9e..7b17c868 100644 --- a/frontend/src/hooks/useThreadAndAssistant.ts +++ b/frontend/src/hooks/useThreadAndAssistant.ts @@ -1,4 +1,4 @@ -import { useQuery } from "react-query"; +import { useQuery, useQueryClient } from "react-query"; import { useParams } from "react-router-dom"; import { getAssistant } from "../api/assistants"; import { getThread } from "../api/threads"; @@ -6,6 +6,7 @@ import { getThread } from "../api/threads"; export function useThreadAndAssistant() { // Extract route parameters const { chatId, assistantId } = useParams(); + const queryClient = useQueryClient(); // React Query to fetch chat details if chatId is present const { data: currentChat, isLoading: isLoadingChat } = useQuery( @@ -28,10 +29,15 @@ export function useThreadAndAssistant() { }, ); + const invalidateChat = (chatId: string) => { + queryClient.invalidateQueries(["thread", chatId]); + }; + // Return both loading states, the chat data, and the assistant configuration return { currentChat, assistantConfig, isLoading: isLoadingChat || isLoadingAssistant, + invalidateChat, }; } diff --git a/frontend/src/types.ts b/frontend/src/types.ts index af698847..f95e0444 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -24,4 +24,5 @@ export interface Chat { thread_id: string; name: string; updated_at: string; + metadata: Record | null; }