Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render documents in UI #148

Merged
merged 6 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backend/app/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from app.agent_types.google_agent import get_google_agent_executor
from app.agent_types.openai_agent import get_openai_agent_executor
from app.agent_types.xml_agent import get_xml_agent_executor
from app.chatbot import get_chatbot_executor
from app.checkpoint import RedisCheckpoint
from app.llms import (
get_anthropic_llm,
get_google_llm,
get_openai_llm,
get_mixtral_fireworks,
get_openai_llm,
)
from app.retrieval import get_retrieval_executor
from app.tools import (
RETRIEVAL_DESCRIPTION,
TOOL_OPTIONS,
Expand All @@ -26,8 +28,6 @@
get_retrieval_tool,
get_retriever,
)
from app.chatbot import get_chatbot_executor
from app.retrieval import get_retrieval_executor


class AgentType(str, Enum):
Expand Down
15 changes: 13 additions & 2 deletions backend/app/agent_types/google_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from langgraph.graph.message import MessageGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

from app.message_types import LiberalFunctionMessage


def get_google_agent_executor(
tools: list[BaseTool],
Expand All @@ -17,7 +19,16 @@ def get_google_agent_executor(
checkpoint: BaseCheckpointSaver,
):
def _get_messages(messages):
return [SystemMessage(content=system_message)] + messages
msgs = []
for m in messages:
if isinstance(m, LiberalFunctionMessage):
_dict = m.dict()
_dict["content"] = str(_dict["content"])
m_c = FunctionMessage(**_dict)
msgs.append(m_c)
else:
msgs.append(m)
return [SystemMessage(content=system_message)] + msgs

if tools:
llm_with_tools = llm.bind(functions=tools)
Expand Down Expand Up @@ -51,7 +62,7 @@ async def call_tool(messages):
# We call the tool_executor and get back a response
response = await tool_executor.ainvoke(action)
# We use the response to create a FunctionMessage
function_message = FunctionMessage(content=str(response), name=action.tool)
function_message = LiberalFunctionMessage(content=response, name=action.tool)
# We return a list, because this will get added to the existing list
return function_message

Expand Down
21 changes: 16 additions & 5 deletions backend/app/agent_types/openai_agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json

from langchain.schema.messages import ToolMessage
from langchain.tools import BaseTool
from langchain.tools.render import format_tool_to_openai_tool
from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessage, ToolMessage
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

from app.message_types import LiberalToolMessage


def get_openai_agent_executor(
tools: list[BaseTool],
Expand All @@ -18,7 +19,17 @@ def get_openai_agent_executor(
checkpoint: BaseCheckpointSaver,
):
def _get_messages(messages):
return [SystemMessage(content=system_message)] + messages
msgs = []
for m in messages:
if isinstance(m, LiberalToolMessage):
_dict = m.dict()
_dict["content"] = str(_dict["content"])
m_c = ToolMessage(**_dict)
msgs.append(m_c)
else:
msgs.append(m)

return [SystemMessage(content=system_message)] + msgs

if tools:
llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
Expand Down Expand Up @@ -58,9 +69,9 @@ async def call_tool(messages):
responses = await tool_executor.abatch(actions)
# We use the response to create a ToolMessage
tool_messages = [
ToolMessage(
LiberalToolMessage(
tool_call_id=tool_call["id"],
content=json.dumps(response),
content=response,
additional_kwargs={"name": tool_call["function"]["name"]},
)
for tool_call, response in zip(
Expand Down
27 changes: 20 additions & 7 deletions backend/app/agent_types/xml_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from langchain.schema.messages import FunctionMessage
from langchain.tools import BaseTool
from langchain.tools.render import render_text_description
from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

from app.agent_types.prompts import xml_template
from app.message_types import LiberalFunctionMessage


def _collapse_messages(messages):
Expand Down Expand Up @@ -39,6 +44,11 @@ def construct_chat_history(messages):
collapsed_messages.append(_collapse_messages(temp_messages))
temp_messages = []
collapsed_messages.append(message)
elif isinstance(message, LiberalFunctionMessage):
_dict = message.dict()
_dict["content"] = str(_dict["content"])
m_c = FunctionMessage(**_dict)
temp_messages.append(m_c)
else:
temp_messages.append(message)

Expand All @@ -61,7 +71,7 @@ def get_xml_agent_executor(
tool_names=", ".join([t.name for t in tools]),
)

llm_with_stop = llm.bind(stop=["</tool_input>"])
llm_with_stop = llm.bind(stop=["</tool_input>", "<observation>"])

def _get_messages(messages):
return [
Expand All @@ -87,17 +97,20 @@ async def call_tool(messages):
# We construct an ToolInvocation from the function_call
tool, tool_input = last_message.content.split("</tool>")
_tool = tool.split("<tool>")[1]
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
if "<tool_input>" not in _tool:
_tool_input = ""
else:
_tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0]
action = ToolInvocation(
tool=_tool,
tool_input=_tool_input,
)
# We call the tool_executor and get back a response
response = await tool_executor.ainvoke(action)
# We use the response to create a FunctionMessage
function_message = FunctionMessage(content=str(response), name=action.tool)
function_message = LiberalFunctionMessage(content=response, name=action.tool)
# We return a list, because this will get added to the existing list
return function_message

Expand Down
11 changes: 11 additions & 0 deletions backend/app/message_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any

from langchain_core.messages import FunctionMessage, ToolMessage


class LiberalFunctionMessage(FunctionMessage):
content: Any


class LiberalToolMessage(ToolMessage):
content: Any
33 changes: 12 additions & 21 deletions backend/app/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
import json
from typing import Any

from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import (
SystemMessage,
HumanMessage,
AIMessage,
FunctionMessage,
)
from langchain_core.runnables import chain
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import chain
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langchain_core.prompts import PromptTemplate


# This is sadly needed to allow for arbitrary typed content
class LiberalFunctionMessage(FunctionMessage):
content: Any

from app.message_types import LiberalFunctionMessage

search_prompt = PromptTemplate.from_template(
"""Given the conversation below, come up with a search query to look up.
Expand Down Expand Up @@ -67,7 +58,7 @@ def _get_messages(messages):
] + chat_history

@chain
def get_search_query(messages):
async def get_search_query(messages):
convo = []
for m in messages:
if isinstance(m, AIMessage):
Expand All @@ -76,11 +67,11 @@ def get_search_query(messages):
if isinstance(m, HumanMessage):
convo.append(f"Human: {m.content}")
conversation = "\n".join(convo)
prompt = search_prompt.invoke({"conversation": conversation})
response = llm.invoke(prompt)
prompt = await search_prompt.ainvoke({"conversation": conversation})
response = await llm.ainvoke(prompt)
return response.content

def invoke_retrieval(messages):
async def invoke_retrieval(messages):
if len(messages) == 1:
human_input = messages[-1].content
return AIMessage(
Expand All @@ -93,7 +84,7 @@ def invoke_retrieval(messages):
},
)
else:
search_query = get_search_query.invoke(messages)
search_query = await get_search_query.ainvoke(messages)
return AIMessage(
content="",
additional_kwargs={
Expand All @@ -104,10 +95,10 @@ def invoke_retrieval(messages):
},
)

def retrieve(messages):
async def retrieve(messages):
params = messages[-1].additional_kwargs["function_call"]
query = json.loads(params["arguments"])["query"]
response = retriever.invoke(query)
response = await retriever.ainvoke(query)
msg = LiberalFunctionMessage(name="retrieval", content=response)
return msg

Expand Down
Loading
Loading