Merge pull request #148 from langchain-ai/nc/jan29/render-documents
Render documents in UI
nfcampos authored Jan 29, 2024
2 parents fd4718d + 25c3c4d commit 9743da5
Showing 11 changed files with 303 additions and 151 deletions.
6 changes: 3 additions & 3 deletions backend/app/agent.py
@@ -11,13 +11,15 @@
 from app.agent_types.google_agent import get_google_agent_executor
 from app.agent_types.openai_agent import get_openai_agent_executor
 from app.agent_types.xml_agent import get_xml_agent_executor
+from app.chatbot import get_chatbot_executor
 from app.checkpoint import RedisCheckpoint
 from app.llms import (
     get_anthropic_llm,
     get_google_llm,
-    get_openai_llm,
     get_mixtral_fireworks,
+    get_openai_llm,
 )
+from app.retrieval import get_retrieval_executor
 from app.tools import (
     RETRIEVAL_DESCRIPTION,
     TOOL_OPTIONS,
@@ -26,8 +28,6 @@
     get_retrieval_tool,
     get_retriever,
 )
-from app.chatbot import get_chatbot_executor
-from app.retrieval import get_retrieval_executor
 
 
 class AgentType(str, Enum):
15 changes: 13 additions & 2 deletions backend/app/agent_types/google_agent.py
@@ -9,6 +9,8 @@
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
+from app.message_types import LiberalFunctionMessage
+
 
 def get_google_agent_executor(
     tools: list[BaseTool],
@@ -17,7 +19,16 @@ def get_google_agent_executor(
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
-        return [SystemMessage(content=system_message)] + messages
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalFunctionMessage):
+                _dict = m.dict()
+                _dict["content"] = str(_dict["content"])
+                m_c = FunctionMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+        return [SystemMessage(content=system_message)] + msgs
 
     if tools:
         llm_with_tools = llm.bind(functions=tools)
@@ -51,7 +62,7 @@ async def call_tool(messages):
         # We call the tool_executor and get back a response
         response = await tool_executor.ainvoke(action)
         # We use the response to create a FunctionMessage
-        function_message = FunctionMessage(content=str(response), name=action.tool)
+        function_message = LiberalFunctionMessage(content=response, name=action.tool)
         # We return a list, because this will get added to the existing list
         return function_message
 
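Aside (not part of the commit): a minimal sketch of the round-trip this enables. Structured tool output stays intact in the checkpointed history and is stringified only at the model boundary, mirroring the loop added to _get_messages above. The document list here is invented.

    from langchain_core.messages import FunctionMessage

    from app.message_types import LiberalFunctionMessage

    # Invented structured tool output; kept as-is in the message history.
    docs = [{"page_content": "LangGraph is a graph runtime.", "metadata": {"source": "readme"}}]
    raw = LiberalFunctionMessage(name="retrieval", content=docs)

    # Coerced to a string only when handed to the model, as in _get_messages.
    _dict = raw.dict()
    _dict["content"] = str(_dict["content"])
    for_model = FunctionMessage(**_dict)
    assert isinstance(for_model.content, str)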
21 changes: 16 additions & 5 deletions backend/app/agent_types/openai_agent.py
@@ -1,15 +1,16 @@
 import json
 
-from langchain.schema.messages import ToolMessage
 from langchain.tools import BaseTool
 from langchain.tools.render import format_tool_to_openai_tool
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import SystemMessage, ToolMessage
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
+from app.message_types import LiberalToolMessage
+
 
 def get_openai_agent_executor(
     tools: list[BaseTool],
@@ -18,7 +19,17 @@ def get_openai_agent_executor(
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
-        return [SystemMessage(content=system_message)] + messages
+        msgs = []
+        for m in messages:
+            if isinstance(m, LiberalToolMessage):
+                _dict = m.dict()
+                _dict["content"] = str(_dict["content"])
+                m_c = ToolMessage(**_dict)
+                msgs.append(m_c)
+            else:
+                msgs.append(m)
+
+        return [SystemMessage(content=system_message)] + msgs
 
     if tools:
         llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
@@ -58,9 +69,9 @@ async def call_tool(messages):
         responses = await tool_executor.abatch(actions)
         # We use the response to create a ToolMessage
         tool_messages = [
-            ToolMessage(
+            LiberalToolMessage(
                 tool_call_id=tool_call["id"],
-                content=json.dumps(response),
+                content=response,
                 additional_kwargs={"name": tool_call["function"]["name"]},
             )
             for tool_call, response in zip(
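Why drop json.dumps here? If I read the change right, a retrieval tool hands back Document objects, which json.dumps rejects outright, while str() at the model boundary always works. A small illustration (the Document is invented):

    import json

    from langchain_core.documents import Document

    doc = Document(page_content="hello", metadata={"source": "faq.md"})

    try:
        json.dumps([doc])
    except TypeError:
        # Document is not JSON serializable, so LiberalToolMessage stores the
        # raw list and defers the str() coercion to _get_messages.
        pass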
27 changes: 20 additions & 7 deletions backend/app/agent_types/xml_agent.py
@@ -1,14 +1,19 @@
-from langchain.schema.messages import FunctionMessage
 from langchain.tools import BaseTool
 from langchain.tools.render import render_text_description
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
 
 from app.agent_types.prompts import xml_template
+from app.message_types import LiberalFunctionMessage
 
 
 def _collapse_messages(messages):
@@ -39,6 +44,11 @@ def construct_chat_history(messages):
                 collapsed_messages.append(_collapse_messages(temp_messages))
                 temp_messages = []
             collapsed_messages.append(message)
+        elif isinstance(message, LiberalFunctionMessage):
+            _dict = message.dict()
+            _dict["content"] = str(_dict["content"])
+            m_c = FunctionMessage(**_dict)
+            temp_messages.append(m_c)
         else:
             temp_messages.append(message)
 
@@ -61,7 +71,7 @@ def get_xml_agent_executor(
         tool_names=", ".join([t.name for t in tools]),
     )
 
-    llm_with_stop = llm.bind(stop=["</tool_input>"])
+    llm_with_stop = llm.bind(stop=["</tool_input>", "<observation>"])
 
     def _get_messages(messages):
         return [
@@ -87,17 +97,20 @@ async def call_tool(messages):
         # We construct an ToolInvocation from the function_call
         tool, tool_input = last_message.content.split("</tool>")
         _tool = tool.split("<tool>")[1]
-        _tool_input = tool_input.split("<tool_input>")[1]
-        if "</tool_input>" in _tool_input:
-            _tool_input = _tool_input.split("</tool_input>")[0]
+        if "<tool_input>" not in tool_input:
+            _tool_input = ""
+        else:
+            _tool_input = tool_input.split("<tool_input>")[1]
+            if "</tool_input>" in _tool_input:
+                _tool_input = _tool_input.split("</tool_input>")[0]
         action = ToolInvocation(
             tool=_tool,
             tool_input=_tool_input,
         )
         # We call the tool_executor and get back a response
         response = await tool_executor.ainvoke(action)
         # We use the response to create a FunctionMessage
-        function_message = FunctionMessage(content=str(response), name=action.tool)
+        function_message = LiberalFunctionMessage(content=response, name=action.tool)
         # We return a list, because this will get added to the existing list
         return function_message
 
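A rough trace (not from the diff) of the parsing above, assuming the model emitted a well-formed call before a stop token fired. The new guard means a bare "<tool>search</tool>" yields an empty tool input instead of the IndexError the old unconditional split would raise, and the added "<observation>" stop token cuts generation off before the model can invent tool results.

    content = "<tool>search</tool><tool_input>weather in sf"
    tool, tool_input = content.split("</tool>")
    _tool = tool.split("<tool>")[1]                    # "search"
    if "<tool_input>" not in tool_input:
        _tool_input = ""                               # bare call: no input block
    else:
        _tool_input = tool_input.split("<tool_input>")[1]
        if "</tool_input>" in _tool_input:             # stop token usually ate it
            _tool_input = _tool_input.split("</tool_input>")[0]
    assert (_tool, _tool_input) == ("search", "weather in sf")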
11 changes: 11 additions & 0 deletions backend/app/message_types.py
@@ -0,0 +1,11 @@
+from typing import Any
+
+from langchain_core.messages import FunctionMessage, ToolMessage
+
+
+class LiberalFunctionMessage(FunctionMessage):
+    content: Any
+
+
+class LiberalToolMessage(ToolMessage):
+    content: Any
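For context, a sketch of what the widened field buys (the Document is invented): the stock message types validate content as text-ish data, so a list of Document objects wouldn't reliably survive validation intact, while the Liberal subclasses keep the structure for the UI to render.

    from langchain_core.documents import Document

    from app.message_types import LiberalFunctionMessage

    docs = [Document(page_content="LangGraph is a graph runtime.", metadata={})]

    # content: Any lets the raw Document list ride along in message history;
    # the agents str()-convert it only when talking to the model.
    msg = LiberalFunctionMessage(name="retrieval", content=docs)
    assert msg.content[0].page_content.startswith("LangGraph")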
33 changes: 12 additions & 21 deletions backend/app/retrieval.py
@@ -1,24 +1,15 @@
 import json
-from typing import Any
 
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import (
-    SystemMessage,
-    HumanMessage,
-    AIMessage,
-    FunctionMessage,
-)
-from langchain_core.runnables import chain
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import chain
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
-from langchain_core.prompts import PromptTemplate
-
-
-# This is sadly needed to allow for arbitrary typed content
-class LiberalFunctionMessage(FunctionMessage):
-    content: Any
-
+
+from app.message_types import LiberalFunctionMessage
 
 search_prompt = PromptTemplate.from_template(
     """Given the conversation below, come up with a search query to look up.
@@ -67,7 +58,7 @@ def _get_messages(messages):
         ] + chat_history
 
     @chain
-    def get_search_query(messages):
+    async def get_search_query(messages):
         convo = []
         for m in messages:
             if isinstance(m, AIMessage):
@@ -76,11 +67,11 @@ def get_search_query(messages):
             if isinstance(m, HumanMessage):
                 convo.append(f"Human: {m.content}")
         conversation = "\n".join(convo)
-        prompt = search_prompt.invoke({"conversation": conversation})
-        response = llm.invoke(prompt)
+        prompt = await search_prompt.ainvoke({"conversation": conversation})
+        response = await llm.ainvoke(prompt)
         return response.content
 
-    def invoke_retrieval(messages):
+    async def invoke_retrieval(messages):
         if len(messages) == 1:
             human_input = messages[-1].content
             return AIMessage(
@@ -93,7 +84,7 @@ def invoke_retrieval(messages):
                 },
             )
         else:
-            search_query = get_search_query.invoke(messages)
+            search_query = await get_search_query.ainvoke(messages)
             return AIMessage(
                 content="",
                 additional_kwargs={
@@ -104,10 +95,10 @@ def invoke_retrieval(messages):
                 },
             )
 
-    def retrieve(messages):
+    async def retrieve(messages):
         params = messages[-1].additional_kwargs["function_call"]
         query = json.loads(params["arguments"])["query"]
-        response = retriever.invoke(query)
+        response = await retriever.ainvoke(query)
         msg = LiberalFunctionMessage(name="retrieval", content=response)
         return msg
 
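One subtlety in the retrieval hops going async end to end: @chain on an async function still returns a Runnable, but one that has to be driven with ainvoke — hence the await get_search_query.ainvoke(...) above. A tiny sketch of the pattern, with invented names:

    import asyncio

    from langchain_core.runnables import chain


    @chain
    async def shout(text: str) -> str:
        # Stand-in for get_search_query: async body behind a Runnable interface.
        return text.upper()


    print(asyncio.run(shout.ainvoke("hello")))  # HELLO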