Skip to content

Commit

Permalink
feature (PersonaFlow-148): Tool streaming (#168)
Browse files Browse the repository at this point in the history
* committed new poetry lock file

* removed userID when fetching from useChat to use default user

* made retrieval_description optional

* added tools folder

* added TToolCall and mapped each tool call to a ToolQuery

* render tool_calls in MessagesContainer

* updated styling for tools

* adjusted border on build-panel

* added tool-result

* added link to url

* remove transition

* removed underline in accordion, added link to visit url

* addressed pr comments
  • Loading branch information
krishokr authored Jul 13, 2024
1 parent 06bb60e commit d504908
Show file tree
Hide file tree
Showing 11 changed files with 729 additions and 531 deletions.
980 changes: 493 additions & 487 deletions poetry.lock

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions ui/src/components/features/build-panel/components/build-panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,25 @@ export default function BuildPanel() {
const [isOpen, setIsOpen] = useState(true);

const drawerStyles = {
open: "p-4 border-solid border-2 h-full flex flex-col gap-4 overflow-x-hidden sm:min-w-[520px]",
open: "p-4 h-full flex flex-col gap-4 overflow-x-hidden sm:min-w-[520px]",
closed: "hidden",
};

return (
<div className="flex items-center">
{isOpen ? (
<ChevronRight
className="cursor-pointer"
onClick={() => setIsOpen((prev) => !prev)}
/>
) : (
<ChevronLeft
className="cursor-pointer"
onClick={() => setIsOpen((prev) => !prev)}
/>
)}
<div className="flex items-center border-solid border-2">
<div className=" p-1">
{isOpen ? (
<ChevronRight
className="cursor-pointer"
onClick={() => setIsOpen((prev) => !prev)}
/>
) : (
<ChevronLeft
className="cursor-pointer"
onClick={() => setIsOpen((prev) => !prev)}
/>
)}
</div>
<div className={isOpen ? drawerStyles["open"] : drawerStyles["closed"]}>
<AssistentBuilder />
</div>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"use client";
import { TMessage } from "@/data-provider/types";
import {
MessageType,
TMessage,
TToolCall,
TToolResult,
} from "@/data-provider/types";
import MessageItem from "./message-item";
import { useParams } from "next/navigation";
import { ReactNode, useEffect, useRef } from "react";
import { ReactNode, useEffect, useRef, useState } from "react";
import Spinner from "@/components/ui/spinner";
import { useChatMessages } from "@/hooks/useChat";
import ToolContainer from "../../tools/tool-container";
import { ToolResult } from "../../tools/tool-result";

type Props = {
isStreaming?: boolean;
Expand Down Expand Up @@ -52,9 +59,31 @@ export default function MessagesContainer({
return (
<div className="h-full flex flex-col">
<div className="p-6 overflow-y-scroll" ref={divRef}>
{messages?.map((message, index) => (
<MessageItem message={message} key={`${message.id}-${index}`} />
))}
{messages?.map((message, index) => {
const isToolCall =
message.tool_calls?.length && message.tool_calls.length > 0;

const isToolResult = message.type === MessageType.TOOL;

if (isToolResult) {
return (
<ToolResult toolResult={message} key={`${message.id}-${index}`} />
);
}

if (isToolCall) {
return (
<ToolContainer
toolCalls={message.tool_calls as TToolCall[]}
key={`${message.id}-${index}`}
/>
);
}

return (
<MessageItem message={message} key={`${message.id}-${index}`} />
);
})}
</div>
{composer}
</div>
Expand Down
6 changes: 5 additions & 1 deletion ui/src/components/features/chat-panel/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ export default function ChatPanel() {

setUserMessage("");

await startStream(input, "", threadId, assistant.id);
await startStream({
input,
thread_id: threadId,
assistant_id: assistant.id,
});
};

if (!threadState && isError)
Expand Down
20 changes: 20 additions & 0 deletions ui/src/components/features/tools/tool-container.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { TToolCall } from "@/data-provider/types";
import ToolQuery from "./tool-query";

type TToolContainer = {
toolCalls: TToolCall[];
};

export default function ToolContainer({ toolCalls }: TToolContainer) {
return (
<div className="flex flex-col py-2 px-3 text-base md:px-4 my-4 mr-auto w-3/4 md:px-5 lg:px-1 xl:px-5 flex-col gap-4">
{toolCalls.map((toolCall) => (
<ToolQuery
query={toolCall.args.query}
tool={toolCall.name}
key={toolCall.id}
/>
))}
</div>
);
}
24 changes: 24 additions & 0 deletions ui/src/components/features/tools/tool-query.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { List, Search } from "lucide-react";

const icons = [
{ title: "query", icon: <List /> },
{ title: "search", icon: <Search /> },
];

type TToolQueryProps = {
query: string;
tool: string;
};

export default function ToolQuery({ query, tool }: TToolQueryProps) {
return (
<>
<h2 className="flex gap-2 rounded-sm border-2 p-2">
<List /> Query: {query}
</h2>
<h2 className="flex gap-2 rounded-sm border-2 p-2">
<Search /> Using: {tool}
</h2>
</>
);
}
78 changes: 78 additions & 0 deletions ui/src/components/features/tools/tool-result.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import {
Accordion,
AccordionItem,
AccordionTrigger,
AccordionContent,
} from "@/components/ui/accordion";
import { Button } from "@/components/ui/button";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/components/ui/collapsible";
import { TMessage } from "@/data-provider/types";
import { CaretSortIcon } from "@radix-ui/react-icons";
import { MoveUpRight } from "lucide-react";
import Link from "next/link";
import { useState } from "react";

const AccordionToolContent = ({
content,
url,
}: {
content: string;
url: string;
}) => {
return (
<Accordion type="single" collapsible>
<AccordionItem value={url}>
<AccordionTrigger className="text-left">
<div className="flex w-full">
{url}
<Button variant="outline" type="button" className="ml-auto mr-2">
<Link
href={url}
target="_blank"
className="flex gap-2 items-center"
>
Visit
<MoveUpRight className="w-4 h-4" />
</Link>
</Button>
</div>
</AccordionTrigger>
<AccordionContent>{content}</AccordionContent>
</AccordionItem>
</Accordion>
);
};

export function ToolResult({ toolResult }: { toolResult: TMessage }) {
const [isOpen, setIsOpen] = useState(false);

return (
<div className="flex flex-col py-2 px-3 text-base md:px-4 mr-auto w-3/4 md:px-5 lg:px-1 xl:px-5 flex-col gap-4">
<Collapsible open={isOpen} onOpenChange={setIsOpen}>
<div className="flex items-center justify-between">
<CollapsibleTrigger asChild>
<Button variant="outline" type="button" className="gap-2 mb-2">
<span>Tool results: {toolResult.name}</span>
<CaretSortIcon className="h-5 w-5" />
</Button>
</CollapsibleTrigger>
</div>
<CollapsibleContent className="space-y-2">
{Array.isArray(toolResult.content)
? toolResult.content.map((item) => (
<AccordionToolContent
url={item.url}
content={item.content}
key={item.url}
/>
))
: toolResult.content}
</CollapsibleContent>
</Collapsible>
</div>
);
}
30 changes: 15 additions & 15 deletions ui/src/components/ui/accordion.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"use client"
"use client";

import * as React from "react"
import * as AccordionPrimitive from "@radix-ui/react-accordion"
import { ChevronDownIcon } from "@radix-ui/react-icons"
import * as React from "react";
import * as AccordionPrimitive from "@radix-ui/react-accordion";
import { ChevronDownIcon } from "@radix-ui/react-icons";

import { cn } from "@/lib/utils"
import { cn } from "@/lib/utils";

const Accordion = AccordionPrimitive.Root
const Accordion = AccordionPrimitive.Root;

const AccordionItem = React.forwardRef<
React.ElementRef<typeof AccordionPrimitive.Item>,
Expand All @@ -17,8 +17,8 @@ const AccordionItem = React.forwardRef<
className={cn("border-b", className)}
{...props}
/>
))
AccordionItem.displayName = "AccordionItem"
));
AccordionItem.displayName = "AccordionItem";

const AccordionTrigger = React.forwardRef<
React.ElementRef<typeof AccordionPrimitive.Trigger>,
Expand All @@ -28,17 +28,17 @@ const AccordionTrigger = React.forwardRef<
<AccordionPrimitive.Trigger
ref={ref}
className={cn(
"flex flex-1 items-center justify-between py-4 text-sm font-medium transition-all hover:underline [&[data-state=open]>svg]:rotate-180",
className
"flex flex-1 items-center justify-between py-4 text-sm font-medium transition-all [&[data-state=open]>svg]:rotate-180",
className,
)}
{...props}
>
{children}
<ChevronDownIcon className="h-4 w-4 shrink-0 text-muted-foreground transition-transform duration-200" />
</AccordionPrimitive.Trigger>
</AccordionPrimitive.Header>
))
AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName
));
AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName;

const AccordionContent = React.forwardRef<
React.ElementRef<typeof AccordionPrimitive.Content>,
Expand All @@ -51,7 +51,7 @@ const AccordionContent = React.forwardRef<
>
<div className={cn("pb-4 pt-0", className)}>{children}</div>
</AccordionPrimitive.Content>
))
AccordionContent.displayName = AccordionPrimitive.Content.displayName
));
AccordionContent.displayName = AccordionPrimitive.Content.displayName;

export { Accordion, AccordionItem, AccordionTrigger, AccordionContent }
export { Accordion, AccordionItem, AccordionTrigger, AccordionContent };
11 changes: 11 additions & 0 deletions ui/src/components/ui/collapsible.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"use client"

import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"

const Collapsible = CollapsiblePrimitive.Root

const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger

const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent

export { Collapsible, CollapsibleTrigger, CollapsibleContent }
30 changes: 26 additions & 4 deletions ui/src/data-provider/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export const formSchema = z.object({
type: z.string().nullable(),
agent_type: z.string().optional(),
llm_type: z.string(),
retrieval_description: z.string(),
retrieval_description: z.string().optional(),
system_message: z.string(),
tools: z.array(z.any()),
}),
Expand All @@ -38,6 +38,7 @@ export type TRunInput = {
export enum MessageType {
AI = "ai",
HUMAN = "human",
TOOL = "tool",
}

export type TConfigurable = {
Expand Down Expand Up @@ -168,21 +169,42 @@ export type TThreadState = {
next: string[];
};

export type TMessage = {
export type TToolCall = {
name: string;
args: {
query: string;
};
id: string;
};

export type TToolResult = {
content: string | [];
type: string;
id: string;
tool_call_id: string;
};

export type TToolContent = {
url: string;
content: string;
};

export type TMessage = {
content: string | TToolContent[];
additional_kwargs?: {
additional_kwargs?: {};
example?: boolean;
};
responsoe_metadata?: {
response_metadata?: {
finish_reason?: boolean;
};
type: string;
name?: string | null;
id: string;
example: boolean;
tool_calls?: string[];
tool_calls?: TToolCall[];
invalid_tool_calls?: string[];
tool_call_id?: string;
};

export type TUpdateMessageRequest = {
Expand Down
14 changes: 8 additions & 6 deletions ui/src/hooks/useStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ import { TMessage, TStreamState } from "@/data-provider/types";
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { useCallback, useState } from "react";

type TStartStreamProps = {
input: TMessage[] | Record<string, any> | null;
thread_id: string;
assistant_id: string;
user_id?: string;
};

export const useStream = () => {
const [current, setCurrent] = useState<TStreamState | null>(null);
const [controller, setController] = useState<AbortController | null>(null);

const startStream = useCallback(
async (
input: TMessage[] | Record<string, any> | null,
user_id: string,
thread_id: string,
assistant_id: string,
) => {
async ({ input, thread_id, assistant_id, user_id }: TStartStreamProps) => {
const controller = new AbortController();
setController(controller);
setCurrent({ status: "inflight", messages: [] });
Expand Down

0 comments on commit d504908

Please sign in to comment.