Skip to content

Commit

Permalink
Make retrieval tool description configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jan 18, 2024
1 parent 283fb71 commit f5707e9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
17 changes: 15 additions & 2 deletions backend/packages/gizmo-agent/gizmo_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
get_openai_function_agent,
# get_xml_agent,
)
from gizmo_agent.tools import TOOL_OPTIONS, TOOLS, AvailableTools, get_retrieval_tool
from gizmo_agent.tools import (
RETRIEVAL_DESCRIPTION,
TOOL_OPTIONS,
TOOLS,
AvailableTools,
get_retrieval_tool,
)

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

Expand All @@ -26,6 +32,7 @@ class ConfigurableAgent(RunnableBinding):
tools: Sequence[str]
agent: GizmoAgentType
system_message: str = DEFAULT_SYSTEM_MESSAGE
retrieval_description: str = RETRIEVAL_DESCRIPTION
assistant_id: Optional[str] = None
user_id: Optional[str] = None

Expand All @@ -36,6 +43,7 @@ def __init__(
agent: GizmoAgentType = GizmoAgentType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
retrieval_description: str = RETRIEVAL_DESCRIPTION,
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[Mapping[str, Any]] = None,
**others: Any,
Expand All @@ -48,7 +56,7 @@ def __init__(
raise ValueError(
"assistant_id must be provided if Retrieval tool is used"
)
_tools.append(get_retrieval_tool(assistant_id))
_tools.append(get_retrieval_tool(assistant_id, retrieval_description))
else:
_tools.append(TOOLS[_tool]())
if agent == GizmoAgentType.GPT_35_TURBO:
Expand All @@ -70,6 +78,7 @@ def __init__(
tools=tools,
agent=agent,
system_message=system_message,
retrieval_description=retrieval_description,
bound=agent_executor,
kwargs=kwargs or {},
config=config or {},
Expand Down Expand Up @@ -110,6 +119,7 @@ class AgentOutput(BaseModel):
agent=GizmoAgentType.GPT_35_TURBO,
tools=[],
system_message=DEFAULT_SYSTEM_MESSAGE,
retrieval_description=RETRIEVAL_DESCRIPTION,
assistant_id=None,
)
.configurable_fields(
Expand All @@ -124,6 +134,9 @@ class AgentOutput(BaseModel):
options=TOOL_OPTIONS,
default=[],
),
retrieval_description=ConfigurableField(
id="retrieval_description", name="Retrieval Description"
),
)
.configurable_alternatives(
ConfigurableField(id="type", name="Bot Type"),
Expand Down
6 changes: 3 additions & 3 deletions backend/packages/gizmo-agent/gizmo_agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ class PythonREPLInput(BaseModel):
query: str = Field(description="python command to run")


RETRIEVER_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
If the user is referencing particular files, that is often a good hint that information may be here."""


def get_retrieval_tool(assistant_id: str):
def get_retrieval_tool(assistant_id: str, description: str):
return create_retriever_tool(
vstore.as_retriever(
search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
),
"Retriever",
RETRIEVER_DESCRIPTION,
description,
)


Expand Down
29 changes: 28 additions & 1 deletion frontend/src/components/Config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ function MultiOptionField(props: {

function PublicLink(props: { assistantId: string }) {
const currentLink = window.location.href;
const link = currentLink.includes('shared_id=') ? currentLink : currentLink + "?shared_id=" + props.assistantId;
const link = currentLink.includes("shared_id=")
? currentLink
: currentLink + "?shared_id=" + props.assistantId;
return (
<div className="flex rounded-md shadow-sm mb-4">
<button
Expand Down Expand Up @@ -320,6 +322,30 @@ export function Config(props: {
readonly={readonly}
/>
);
} else if (
key === "type==agent/retrieval_description" &&
(
values?.configurable?.["type==agent/tools"] as
| string[]
| undefined
)?.includes("Retrieval")
) {
return (
<StringField
key={key}
id={key}
field={value}
title={title}
value={values?.configurable?.[key] as string}
setValue={(value: string) =>
setValues({
...values,
configurable: { ...values!.configurable, [key]: value },
})
}
readonly={readonly}
/>
);
} else if (key === "type==agent/tools") {
return (
<MultiOptionField
Expand Down Expand Up @@ -347,6 +373,7 @@ export function Config(props: {
setFiles={setFiles}
/>
)}
{}
<SingleOptionField
id="public"
field={{
Expand Down

0 comments on commit f5707e9

Please sign in to comment.