diff --git a/backend/app/agent.py b/backend/app/agent.py
index 23ea62a3..5df7c7ef 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -11,13 +11,15 @@
 from app.agent_types.google_agent import get_google_agent_executor
 from app.agent_types.openai_agent import get_openai_agent_executor
 from app.agent_types.xml_agent import get_xml_agent_executor
+from app.chatbot import get_chatbot_executor
 from app.checkpoint import RedisCheckpoint
 from app.llms import (
     get_anthropic_llm,
     get_google_llm,
-    get_openai_llm,
     get_mixtral_fireworks,
+    get_openai_llm,
 )
+from app.retrieval import get_retrieval_executor
 from app.tools import (
     RETRIEVAL_DESCRIPTION,
     TOOL_OPTIONS,
@@ -26,8 +28,6 @@
     get_retrieval_tool,
     get_retriever,
 )
-from app.chatbot import get_chatbot_executor
-from app.retrieval import get_retrieval_executor
 
 
 class AgentType(str, Enum):
diff --git a/backend/app/agent_types/google_agent.py b/backend/app/agent_types/google_agent.py
index daf8752e..499dfbf2 100644
--- a/backend/app/agent_types/google_agent.py
+++ b/backend/app/agent_types/google_agent.py
@@ -9,6 +9,8 @@
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
+from app.message_types import LiberalFunctionMessage
+
 
 def get_google_agent_executor(
     tools: list[BaseTool],
@@ -17,7 +19,16 @@
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
-        return [SystemMessage(content=system_message)] + messages
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalFunctionMessage):
+                _dict = m.dict()
+                _dict["content"] = str(_dict["content"])
+                m_c = FunctionMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+        return [SystemMessage(content=system_message)] + msgs
 
     if tools:
         llm_with_tools = llm.bind(functions=tools)
@@ -51,7 +62,7 @@
         # We call the tool_executor and get back a response
         response = await tool_executor.ainvoke(action)
         # We use the response to create a FunctionMessage
-        function_message = FunctionMessage(content=str(response), name=action.tool)
+        function_message = LiberalFunctionMessage(content=response, name=action.tool)
         # We return a list, because this will get added to the existing list
         return function_message
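The `_get_messages` change above is the send-side half of a store-raw, stringify-on-send pattern. A minimal sketch of that conversion, using a stub document payload that is illustrative rather than part of the diff:

```python
from langchain_core.messages import FunctionMessage

from app.message_types import LiberalFunctionMessage

# The graph checkpoints the tool result as-is, e.g. structured documents.
stored = LiberalFunctionMessage(
    name="retrieval",
    content=[{"page_content": "...", "metadata": {"page": 1}}],
)

# Right before calling the model, _get_messages rebuilds a plain
# FunctionMessage, since the model only accepts string content here.
_dict = stored.dict()
_dict["content"] = str(_dict["content"])
sendable = FunctionMessage(**_dict)
assert isinstance(sendable.content, str)
```
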
diff --git a/backend/app/agent_types/openai_agent.py b/backend/app/agent_types/openai_agent.py
index c4bf05d0..93ed1215 100644
--- a/backend/app/agent_types/openai_agent.py
+++ b/backend/app/agent_types/openai_agent.py
@@ -1,15 +1,16 @@
 import json
 
-from langchain.schema.messages import ToolMessage
 from langchain.tools import BaseTool
 from langchain.tools.render import format_tool_to_openai_tool
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import SystemMessage, ToolMessage
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
+from app.message_types import LiberalToolMessage
+
 
 def get_openai_agent_executor(
     tools: list[BaseTool],
@@ -18,7 +19,17 @@
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
-        return [SystemMessage(content=system_message)] + messages
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalToolMessage):
+                _dict = m.dict()
+                _dict["content"] = str(_dict["content"])
+                m_c = ToolMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+
+        return [SystemMessage(content=system_message)] + msgs
 
     if tools:
         llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
@@ -58,9 +69,9 @@
         responses = await tool_executor.abatch(actions)
         # We use the response to create a ToolMessage
         tool_messages = [
-            ToolMessage(
+            LiberalToolMessage(
                 tool_call_id=tool_call["id"],
-                content=json.dumps(response),
+                content=response,
                 additional_kwargs={"name": tool_call["function"]["name"]},
             )
             for tool_call, response in zip(
diff --git a/backend/app/agent_types/xml_agent.py b/backend/app/agent_types/xml_agent.py
index 2a1654cc..6e194f76 100644
--- a/backend/app/agent_types/xml_agent.py
+++ b/backend/app/agent_types/xml_agent.py
@@ -1,14 +1,19 @@
-from langchain.schema.messages import FunctionMessage
 from langchain.tools import BaseTool
 from langchain.tools.render import render_text_description
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
 from app.agent_types.prompts import xml_template
+from app.message_types import LiberalFunctionMessage
 
 
 def _collapse_messages(messages):
@@ -39,6 +44,11 @@
             collapsed_messages.append(_collapse_messages(temp_messages))
             temp_messages = []
             collapsed_messages.append(message)
+        elif isinstance(message, LiberalFunctionMessage):
+            _dict = message.dict()
+            _dict["content"] = str(_dict["content"])
+            m_c = FunctionMessage(**_dict)
+            temp_messages.append(m_c)
         else:
             temp_messages.append(message)
 
@@ -61,7 +71,7 @@
         tool_names=", ".join([t.name for t in tools]),
     )
 
-    llm_with_stop = llm.bind(stop=["</tool_input>"])
+    llm_with_stop = llm.bind(stop=["</tool_input>", "<observation>"])
 
     def _get_messages(messages):
         return [
@@ -87,9 +97,12 @@
         # We construct an ToolInvocation from the function_call
         tool, tool_input = last_message.content.split("</tool>")
         _tool = tool.split("<tool>")[1]
-        _tool_input = tool_input.split("<tool_input>")[1]
-        if "</tool_input>" in _tool_input:
-            _tool_input = _tool_input.split("</tool_input>")[0]
+        if "<tool_input>" not in tool_input:
+            _tool_input = ""
+        else:
+            _tool_input = tool_input.split("<tool_input>")[1]
+            if "</tool_input>" in _tool_input:
+                _tool_input = _tool_input.split("</tool_input>")[0]
         action = ToolInvocation(
             tool=_tool,
             tool_input=_tool_input,
@@ -97,7 +110,7 @@
         # We call the tool_executor and get back a response
         response = await tool_executor.ainvoke(action)
         # We use the response to create a FunctionMessage
-        function_message = FunctionMessage(content=str(response), name=action.tool)
+        function_message = LiberalFunctionMessage(content=response, name=action.tool)
         # We return a list, because this will get added to the existing list
         return function_message
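For reference, here is the parsing logic from `call_tool` above as a standalone sketch, assuming model output in the `<tool>...</tool><tool_input>...</tool_input>` shape that `xml_template` requests. The new branch covers calls that omit `<tool_input>` entirely; the trailing-tag check exists because generation usually stops at the `</tool_input>` stop token before the closing tag is emitted:

```python
def parse_tool_call(content: str) -> tuple[str, str]:
    """Extract (tool, tool_input) from e.g. '<tool>search</tool><tool_input>foo'."""
    tool, tool_input = content.split("</tool>")
    _tool = tool.split("<tool>")[1]
    if "<tool_input>" not in tool_input:
        _tool_input = ""  # the model invoked the tool with no input
    else:
        _tool_input = tool_input.split("<tool_input>")[1]
        if "</tool_input>" in _tool_input:
            _tool_input = _tool_input.split("</tool_input>")[0]
    return _tool, _tool_input


assert parse_tool_call("<tool>search</tool><tool_input>foo") == ("search", "foo")
assert parse_tool_call("<tool>search</tool>") == ("search", "")
```
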
diff --git a/backend/app/message_types.py b/backend/app/message_types.py
new file mode 100644
index 00000000..c38e66e4
--- /dev/null
+++ b/backend/app/message_types.py
@@ -0,0 +1,11 @@
+from typing import Any
+
+from langchain_core.messages import FunctionMessage, ToolMessage
+
+
+class LiberalFunctionMessage(FunctionMessage):
+    content: Any
+
+
+class LiberalToolMessage(ToolMessage):
+    content: Any
diff --git a/backend/app/retrieval.py b/backend/app/retrieval.py
index 42dd6093..4a713129 100644
--- a/backend/app/retrieval.py
+++ b/backend/app/retrieval.py
@@ -1,24 +1,15 @@
 import json
-from typing import Any
 
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import (
-    SystemMessage,
-    HumanMessage,
-    AIMessage,
-    FunctionMessage,
-)
-from langchain_core.runnables import chain
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import chain
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
-from langchain_core.prompts import PromptTemplate
-
-
-# This is sadly needed to allow for arbitrary typed content
-class LiberalFunctionMessage(FunctionMessage):
-    content: Any
+from app.message_types import LiberalFunctionMessage
 
 search_prompt = PromptTemplate.from_template(
     """Given the conversation below, come up with a search query to look up.
@@ -67,7 +58,7 @@ def _get_messages(messages):
         ] + chat_history
 
     @chain
-    def get_search_query(messages):
+    async def get_search_query(messages):
         convo = []
         for m in messages:
             if isinstance(m, AIMessage):
@@ -76,11 +67,11 @@
             if isinstance(m, HumanMessage):
                 convo.append(f"Human: {m.content}")
         conversation = "\n".join(convo)
-        prompt = search_prompt.invoke({"conversation": conversation})
-        response = llm.invoke(prompt)
+        prompt = await search_prompt.ainvoke({"conversation": conversation})
+        response = await llm.ainvoke(prompt)
         return response.content
 
-    def invoke_retrieval(messages):
+    async def invoke_retrieval(messages):
         if len(messages) == 1:
             human_input = messages[-1].content
             return AIMessage(
@@ -93,7 +84,7 @@
                 },
             )
         else:
-            search_query = get_search_query.invoke(messages)
+            search_query = await get_search_query.ainvoke(messages)
             return AIMessage(
                 content="",
                 additional_kwargs={
@@ -104,10 +95,10 @@
                 },
             )
 
-    def retrieve(messages):
+    async def retrieve(messages):
         params = messages[-1].additional_kwargs["function_call"]
         query = json.loads(params["arguments"])["query"]
-        response = retriever.invoke(query)
+        response = await retriever.ainvoke(query)
         msg = LiberalFunctionMessage(name="retrieval", content=response)
         return msg
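One consequence of the two changes above, shown as a small sketch with a stub `Document` that is illustrative only: because `retrieve` no longer serializes its result, the checkpointed message keeps the raw `Document` list, which is what the new frontend `DocumentList` component renders.

```python
from langchain_core.documents import Document

from app.message_types import LiberalFunctionMessage

docs = [
    Document(page_content="LangGraph is ...", metadata={"source": "intro.md"}),
]

# The retrieve node now stores the documents themselves rather than
# json.dumps(...) text, so the UI can show page_content and metadata
# as structured data instead of one opaque string.
msg = LiberalFunctionMessage(name="retrieval", content=docs)
assert isinstance(msg.content, list)
```
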
diff --git a/frontend/src/components/Config.tsx b/frontend/src/components/Config.tsx
index d1dad7c5..d06efadc 100644
--- a/frontend/src/components/Config.tsx
+++ b/frontend/src/components/Config.tsx
@@ -377,7 +377,19 @@ export function Config(props: {
   );
   return (
-    <div>
+    <form
+      onSubmit={async (e) => {
+        e.preventDefault();
+        e.stopPropagation();
+        const form = e.target as HTMLFormElement;
+        const key = form.key.value;
+        if (!key) return;
+        setInflight(true);
+        await props.saveConfig(key, values!, files, isPublic);
+        setInflight(false);
+      }}
+    >
       {settings}
       {typeField && (
       )}
-      <form
-        onSubmit={async (e) => {
-          e.preventDefault();
-          e.stopPropagation();
-          const form = e.target as HTMLFormElement;
-          const key = form.key.value;
-          if (!key) return;
-          setInflight(true);
-          await props.saveConfig(key, values!, files, isPublic);
-          setInflight(false);
-        }}
-      >
-        {!props.config && typeSpec?.files && (
-          <FileUploadDropzone
-            state={dropzone}
-            files={files}
-            setFiles={setFiles}
-          />
-        )}
+      {!props.config && typeSpec?.files && (
+        <FileUploadDropzone
+          state={dropzone}
+          files={files}
+          setFiles={setFiles}
+        />
+      )}
-        {orderBy(
-          Object.entries(
-            props.configSchema?.properties.configurable.properties ?? {}
-          ),
-          ([key]) => ORDER.indexOf(last(key.split("/"))!)
-        ).map(([key, value]) => {
-          const title = value.title;
-          if (key.split("/")[0].includes("==")) {
-            const [parentKey, parentValue] = key.split("/")[0].split("==");
-            if (values?.configurable?.[parentKey] !== parentValue) {
-              return null;
-            }
-          } else {
-            return null;
-          }
-          if (
-            last(key.split("/")) === "retrieval_description" &&
-            !files.length
-          ) {
-            return null;
-          }
-          if (value.type === "string" && value.enum) {
-            return (
-              <SingleOptionField
-                key={key}
-                id={key}
-                field={value}
-                title={title}
-                value={values?.configurable?.[key] as string}
-                setValue={(value) =>
-                  setValues({
-                    ...values,
-                    configurable: { ...values!.configurable, [key]: value },
-                  })
-                }
-                readonly={readonly}
-              />
-            );
-          } else if (value.type === "string") {
-            return (
-              <StringField
-                key={key}
-                id={key}
-                field={value}
-                title={title}
-                value={values?.configurable?.[key] as string}
-                setValue={(value) =>
-                  setValues({
-                    ...values,
-                    configurable: { ...values!.configurable, [key]: value },
-                  })
-                }
-                readonly={readonly}
-              />
-            );
-          } else if (
-            value.type === "array" &&
-            value.items?.type === "string" &&
-            value.items?.enum
-          ) {
-            return (
-              <MultiOptionField
-                key={key}
-                id={key}
-                field={value}
-                title={title}
-                value={values?.configurable?.[key] as string[]}
-                setValue={(value) =>
-                  setValues({
-                    ...values,
-                    configurable: { ...values!.configurable, [key]: value },
-                  })
-                }
-                readonly={readonly}
-                descriptions={TOOL_DESCRIPTIONS}
-              />
-            );
-          }
-        })}
-      </form>
-    </div>
+      {orderBy(
+        Object.entries(
+          props.configSchema?.properties.configurable.properties ?? {}
+        ),
+        ([key]) => ORDER.indexOf(last(key.split("/"))!)
+      ).map(([key, value]) => {
+        const title = value.title;
+        if (key.split("/")[0].includes("==")) {
+          const [parentKey, parentValue] = key.split("/")[0].split("==");
+          if (values?.configurable?.[parentKey] !== parentValue) {
+            return null;
+          }
+        } else {
+          return null;
+        }
+        if (
+          last(key.split("/")) === "retrieval_description" &&
+          !files.length
+        ) {
+          return null;
+        }
+        if (value.type === "string" && value.enum) {
+          return (
+            <SingleOptionField
+              key={key}
+              id={key}
+              field={value}
+              title={title}
+              value={values?.configurable?.[key] as string}
+              setValue={(value) =>
+                setValues({
+                  ...values,
+                  configurable: { ...values!.configurable, [key]: value },
+                })
+              }
+              readonly={readonly}
+            />
+          );
+        } else if (value.type === "string") {
+          return (
+            <StringField
+              key={key}
+              id={key}
+              field={value}
+              title={title}
+              value={values?.configurable?.[key] as string}
+              setValue={(value) =>
+                setValues({
+                  ...values,
+                  configurable: { ...values!.configurable, [key]: value },
+                })
+              }
+              readonly={readonly}
+            />
+          );
+        } else if (
+          value.type === "array" &&
+          value.items?.type === "string" &&
+          value.items?.enum
+        ) {
+          return (
+            <MultiOptionField
+              key={key}
+              id={key}
+              field={value}
+              title={title}
+              value={values?.configurable?.[key] as string[]}
+              setValue={(value) =>
+                setValues({
+                  ...values,
+                  configurable: { ...values!.configurable, [key]: value },
+                })
+              }
+              readonly={readonly}
+              descriptions={TOOL_DESCRIPTIONS}
+            />
+          );
+        }
+      })}
+    </form>
   );
 }
diff --git a/frontend/src/components/Document.tsx b/frontend/src/components/Document.tsx
new file mode 100644
index 00000000..39e84a72
--- /dev/null
+++ b/frontend/src/components/Document.tsx
@@ -0,0 +1,100 @@
+import { useMemo, useState } from "react";
+import { cn } from "../utils/cn";
+import { ChevronDownIcon, ChevronRightIcon } from "@heroicons/react/24/outline";
+
+export interface PageDocument {
+  page_content: string;
+  metadata: Record<string, unknown>;
+}
+
+function PageDocument(props: { document: PageDocument; className?: string }) {
+  const [open, setOpen] = useState(false);
+
+  const metadata = useMemo(() => {
+    return Object.keys(props.document.metadata)
+      .sort((a, b) => {
+        const aValue = JSON.stringify(props.document.metadata[a]);
+        const bValue = JSON.stringify(props.document.metadata[b]);
+
+        const aLines = aValue.split("\n");
+        const bLines = bValue.split("\n");
+
+        if (aLines.length !== bLines.length) {
+          return aLines.length - bLines.length;
+        }
+
+        return aValue.length - bValue.length;
+      })
+      .map((key) => {
+        const value = props.document.metadata[key];
+        return {
+          key,
+          value:
+            typeof value === "string" || typeof value === "number"
+              ? `${value}`
+              : JSON.stringify(value),
+        };
+      });
+  }, [props.document.metadata]);
+
+  if (!open) {
+    return (
+      <button
+        className={cn("flex items-start", props.className)}
+        onClick={() => setOpen(true)}
+      >
+        <ChevronRightIcon className="h-4 w-4 shrink-0" />
+        <span className="text-left line-clamp-2">
+          {props.document.page_content}
+        </span>
+      </button>
+    );
+  }
+
+  return (
+    <div className={cn("flex flex-col items-start", props.className)}>
+      <button
+        className="flex items-start text-left"
+        onClick={() => setOpen(false)}
+      >
+        <ChevronDownIcon className="h-4 w-4 shrink-0" />
+        <span>{props.document.page_content}</span>
+      </button>
+      <div className="flex flex-col gap-1">
+        {metadata.map(({ key, value }) => (
+          <div key={key}>
+            <span className="font-semibold">{key}:</span> <span>{value}</span>
+          </div>
+        ))}
+      </div>
+    </div>
+  );
+}
+
+export function DocumentList(props: { documents: PageDocument[] }) {
+  return (
+    <div>
+      {props.documents.map((document, idx) => (
+        <PageDocument key={idx} document={document} />
+      ))}
+    </div>
+  );
+}
diff --git a/frontend/src/components/FileUpload.tsx b/frontend/src/components/FileUpload.tsx
index 6bbac49d..e7276893 100644
--- a/frontend/src/components/FileUpload.tsx
+++ b/frontend/src/components/FileUpload.tsx
@@ -45,6 +45,7 @@ export function FileUploadDropzone(props: {
   state: DropzoneState;
   files: File[];
   setFiles: React.Dispatch<React.SetStateAction<File[]>>;
+  className?: string;
 }) {
   const { getRootProps, getInputProps, fileRejections } = props.state;
 
@@ -74,7 +75,7 @@ export function FileUploadDropzone(props: {
   );
 
   return (
-    <div
+    <div