Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deleting assistants & attaching orphaned threads to new assistants #328

Merged
merged 13 commits into from
May 8, 2024
Merged
10 changes: 10 additions & 0 deletions backend/app/api/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ async def upsert_assistant(
config=payload.config,
public=payload.public,
)


@router.delete("/{aid}")
async def delete_assistant(
user: AuthedUser,
aid: AssistantID,
):
"""Delete an assistant by ID."""
await storage.delete_assistant(user["user_id"], aid)
return {"status": "ok"}
15 changes: 12 additions & 3 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ async def get_thread_state(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.get_thread_state(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
assistant=assistant,
)


Expand All @@ -61,11 +64,14 @@ async def add_thread_state(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}},
payload.values,
user_id=user["user_id"],
assistant_id=thread["assistant_id"],
assistant=assistant,
)


Expand All @@ -78,10 +84,13 @@ async def get_thread_history(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.get_thread_history(
user_id=user["user_id"],
thread_id=tid,
assistant_id=thread["assistant_id"],
assistant=assistant,
)


Expand Down
6 changes: 6 additions & 0 deletions backend/app/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ async def _init_connection(conn) -> None:
decoder=orjson.loads,
schema="pg_catalog",
)
await conn.set_type_codec(
"jsonb",
encoder=lambda v: orjson.dumps(v).decode(),
decoder=orjson.loads,
schema="pg_catalog",
)
await conn.set_type_codec(
"uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog"
)
Expand Down
1 change: 1 addition & 0 deletions backend/app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ class Thread(TypedDict):
"""The name of the thread."""
updated_at: datetime
"""The last time the thread was updated."""
metadata: Optional[dict]
58 changes: 37 additions & 21 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ async def put_assistant(
}


async def delete_assistant(user_id: str, assistant_id: str) -> None:
"""Delete an assistant by ID."""
async with get_pg_pool().acquire() as conn:
await conn.execute(
"DELETE FROM assistant WHERE assistant_id = $1 AND user_id = $2",
assistant_id,
user_id,
)


async def list_threads(user_id: str) -> List[Thread]:
"""List all threads for the current user."""
async with get_pg_pool().acquire() as conn:
Expand All @@ -92,15 +102,14 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
)


async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str):
async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant):
"""Get state for a thread."""
assistant = await get_assistant(user_id, assistant_id)
state = await agent.aget_state(
{
"configurable": {
**assistant["config"]["configurable"],
"thread_id": thread_id,
"assistant_id": assistant_id,
"assistant_id": assistant["assistant_id"],
}
}
)
Expand All @@ -115,25 +124,23 @@ async def update_thread_state(
values: Union[Sequence[AnyMessage], dict[str, Any]],
*,
user_id: str,
assistant_id: str,
assistant: Assistant,
):
"""Add state to a thread."""
assistant = await get_assistant(user_id, assistant_id)
await agent.aupdate_state(
{
"configurable": {
**assistant["config"]["configurable"],
**config["configurable"],
"assistant_id": assistant_id,
"assistant_id": assistant["assistant_id"],
}
},
values,
)


async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str):
async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assistant):
"""Get the history of a thread."""
assistant = await get_assistant(user_id, assistant_id)
return [
{
"values": c.values,
Expand All @@ -146,7 +153,7 @@ async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str)
"configurable": {
**assistant["config"]["configurable"],
"thread_id": thread_id,
"assistant_id": assistant_id,
"assistant_id": assistant["assistant_id"],
}
}
)
Expand All @@ -158,31 +165,50 @@ async def put_thread(
) -> Thread:
"""Modify a thread."""
updated_at = datetime.now(timezone.utc)
assistant = await get_assistant(user_id, assistant_id)
metadata = (
{"assistant_type": assistant["config"]["configurable"]["type"]}
if assistant
else None
)
async with get_pg_pool().acquire() as conn:
await conn.execute(
(
"INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at) VALUES ($1, $2, $3, $4, $5) "
"INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at, metadata) VALUES ($1, $2, $3, $4, $5, $6) "
"ON CONFLICT (thread_id) DO UPDATE SET "
"user_id = EXCLUDED.user_id,"
"assistant_id = EXCLUDED.assistant_id, "
"name = EXCLUDED.name, "
"updated_at = EXCLUDED.updated_at;"
"updated_at = EXCLUDED.updated_at, "
"metadata = EXCLUDED.metadata;"
),
thread_id,
user_id,
assistant_id,
name,
updated_at,
metadata,
)
return {
"thread_id": thread_id,
"user_id": user_id,
"assistant_id": assistant_id,
"name": name,
"updated_at": updated_at,
"metadata": metadata,
}


async def delete_thread(user_id: str, thread_id: str):
"""Delete a thread by ID."""
async with get_pg_pool().acquire() as conn:
await conn.execute(
"DELETE FROM thread WHERE thread_id = $1 AND user_id = $2",
thread_id,
user_id,
)


async def get_or_create_user(sub: str) -> tuple[User, bool]:
"""Returns a tuple of the user and a boolean indicating whether the user was created."""
async with get_pg_pool().acquire() as conn:
Expand All @@ -192,13 +218,3 @@ async def get_or_create_user(sub: str) -> tuple[User, bool]:
'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub
)
return user, True


async def delete_thread(user_id: str, thread_id: str):
"""Delete a thread by ID."""
async with get_pg_pool().acquire() as conn:
await conn.execute(
"DELETE FROM thread WHERE thread_id = $1 AND user_id = $2",
thread_id,
user_id,
)
2 changes: 2 additions & 0 deletions backend/migrations/000004_add_metadata_to_thread.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE thread
DROP COLUMN metadata;
9 changes: 9 additions & 0 deletions backend/migrations/000004_add_metadata_to_thread.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
ALTER TABLE thread
ADD COLUMN metadata JSONB;

UPDATE thread
SET metadata = json_build_object(
'assistant_type', (SELECT config->'configurable'->>'type'
FROM assistant
WHERE assistant.assistant_id = thread.assistant_id)
);
9 changes: 8 additions & 1 deletion backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ async def test_threads() -> None:
assert response.status_code == 200
assert [
_project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json()
] == [{"assistant_id": aid, "name": "bobby", "thread_id": tid}]
] == [
{
"assistant_id": aid,
"name": "bobby",
"thread_id": tid,
"metadata": {"assistant_type": "chatbot"},
}
]

response = await client.put(
f"/threads/{tid}",
Expand Down
9 changes: 7 additions & 2 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import { MessageWithFiles } from "./utils/formTypes.ts";
import { useNavigate } from "react-router-dom";
import { useThreadAndAssistant } from "./hooks/useThreadAndAssistant.ts";
import { Message } from "./types.ts";
import { OrphanChat } from "./components/OrphanChat.tsx";

function App(props: { edit?: boolean }) {
const navigate = useNavigate();
const [sidebarOpen, setSidebarOpen] = useState(false);
const { chats, createChat, deleteChat } = useChatList();
const { configs, saveConfig } = useConfigList();
const { chats, createChat, updateChat, deleteChat } = useChatList();
const { configs, saveConfig, deleteConfig } = useConfigList();
const { startStream, stopStream, stream } = useStreamState();
const { configSchema, configDefaults } = useSchemas();

Expand Down Expand Up @@ -145,6 +146,9 @@ function App(props: { edit?: boolean }) {
{currentChat && assistantConfig && (
<Chat startStream={startTurn} stopStream={stopStream} stream={stream} />
)}
{currentChat && !assistantConfig && (
<OrphanChat chat={currentChat} updateChat={updateChat} />
)}
{!currentChat && assistantConfig && !props.edit && (
<NewChat
startChat={startChat}
Expand All @@ -153,6 +157,7 @@ function App(props: { edit?: boolean }) {
configs={configs}
saveConfig={saveConfig}
enterConfig={selectConfig}
deleteConfig={deleteConfig}
/>
)}
{!currentChat && assistantConfig && props.edit && (
Expand Down
13 changes: 13 additions & 0 deletions frontend/src/api/assistants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,16 @@ export async function getAssistant(
return null;
}
}

export async function getAssistants(): Promise<Config[] | null> {
try {
const response = await fetch(`/assistants/`);
if (!response.ok) {
return null;
}
return (await response.json()) as Config[];
} catch (error) {
console.error("Failed to fetch assistants:", error);
return null;
}
}
8 changes: 7 additions & 1 deletion frontend/src/components/ChatList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,13 @@ export function ChatList(props: {
role="menuitem"
onClick={(event) => {
event.preventDefault();
props.deleteChat(chat.thread_id);
if (
window.confirm(
`Are you sure you want to delete chat "${chat.name}"?`,
)
) {
props.deleteChat(chat.thread_id);
}
}}
>
Delete
Expand Down
22 changes: 21 additions & 1 deletion frontend/src/components/ConfigList.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { TYPES } from "../constants";
import { Config, ConfigListProps } from "../hooks/useConfigList";
import { cn } from "../utils/cn";
import { PencilSquareIcon } from "@heroicons/react/24/outline";
import { PencilSquareIcon, TrashIcon } from "@heroicons/react/24/outline";
import { Link } from "react-router-dom";

function ConfigItem(props: {
config: Config;
currentConfig: Config | null;
enterConfig: (id: string | null) => void;
deleteConfig: (id: string) => void;
}) {
return (
<li key={props.config.assistant_id}>
Expand Down Expand Up @@ -50,6 +51,22 @@ function ConfigItem(props: {
>
<PencilSquareIcon />
</Link>
<Link
className="w-5"
to="#"
onClick={(event) => {
event.preventDefault();
if (
window.confirm(
`Are you sure you want to delete bot "${props.config.name}?"`,
)
) {
props.deleteConfig(props.config.assistant_id);
}
}}
>
<TrashIcon />
</Link>
</div>
</li>
);
Expand All @@ -59,6 +76,7 @@ export function ConfigList(props: {
configs: ConfigListProps["configs"];
currentConfig: Config | null;
enterConfig: (id: string | null) => void;
deleteConfig: (id: string) => void;
}) {
return (
<>
Expand All @@ -74,6 +92,7 @@ export function ConfigList(props: {
config={assistant}
currentConfig={props.currentConfig}
enterConfig={props.enterConfig}
deleteConfig={props.deleteConfig}
/>
)) ?? (
<li className="leading-6 p-2 animate-pulse font-black text-gray-400 text-lg">
Expand All @@ -94,6 +113,7 @@ export function ConfigList(props: {
config={assistant}
currentConfig={props.currentConfig}
enterConfig={props.enterConfig}
deleteConfig={props.deleteConfig}
/>
)) ?? (
<li className="leading-6 p-2 animate-pulse font-black text-gray-400 text-lg">
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/components/NewChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ interface NewChatProps extends ConfigListProps {
configSchema: Schemas["configSchema"];
configDefaults: Schemas["configDefaults"];
enterConfig: (id: string | null) => void;
deleteConfig: (id: string) => Promise<void>;
startChat: (
config: ConfigInterface,
message: MessageWithFiles,
Expand All @@ -39,11 +40,12 @@ export function NewChat(props: NewChatProps) {
)}
>
<div className="flex-1 flex flex-col md:flex-row lg:items-stretch self-stretch">
<div className="w-72 border-r border-gray-200 pr-6">
<div className="md:w-72 border-r border-gray-200 pr-6">
<ConfigList
configs={props.configs}
currentConfig={assistantConfig}
enterConfig={(id) => navigator(`/assistant/${id}`)}
deleteConfig={props.deleteConfig}
/>
</div>

Expand Down
Loading
Loading