diff --git a/backend/app/agent.py b/backend/app/agent.py index 4e3653a3..512d85b2 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -8,8 +8,7 @@ ) from langgraph.checkpoint import CheckpointAt -from app.agent_types.google_agent import get_google_agent_executor -from app.agent_types.openai_agent import get_openai_agent_executor +from app.agent_types.tools_agent import get_tools_agent_executor from app.agent_types.xml_agent import get_xml_agent_executor from app.chatbot import get_chatbot_executor from app.checkpoint import PostgresCheckpoint @@ -82,22 +81,22 @@ def get_agent_executor( ): if agent == AgentType.GPT_35_TURBO: llm = get_openai_llm() - return get_openai_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) elif agent == AgentType.GPT_4: llm = get_openai_llm(gpt_4=True) - return get_openai_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) elif agent == AgentType.AZURE_OPENAI: llm = get_openai_llm(azure=True) - return get_openai_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) elif agent == AgentType.CLAUDE2: llm = get_anthropic_llm() - return get_xml_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) elif agent == AgentType.BEDROCK_CLAUDE2: @@ -107,12 +106,12 @@ def get_agent_executor( ) elif agent == AgentType.GEMINI: llm = get_google_llm() - return get_google_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) elif agent == AgentType.OLLAMA: llm = get_ollama_llm() - return get_openai_agent_executor( + return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) diff --git a/backend/app/agent_types/google_agent.py b/backend/app/agent_types/google_agent.py deleted file mode 100644 index 769a2347..00000000 --- a/backend/app/agent_types/google_agent.py +++ /dev/null @@ -1,111 +0,0 @@ -import json - -from langchain.schema.messages import FunctionMessage -from langchain.tools import BaseTool -from langchain_core.language_models.base import LanguageModelLike -from langchain_core.messages import SystemMessage -from langgraph.checkpoint import BaseCheckpointSaver -from langgraph.graph import END -from langgraph.graph.message import MessageGraph -from langgraph.prebuilt import ToolExecutor, ToolInvocation - -from app.message_types import LiberalFunctionMessage - - -def get_google_agent_executor( - tools: list[BaseTool], - llm: LanguageModelLike, - system_message: str, - interrupt_before_action: bool, - checkpoint: BaseCheckpointSaver, -): - def _get_messages(messages): - msgs = [] - for m in messages: - if isinstance(m, LiberalFunctionMessage): - _dict = m.dict() - _dict["content"] = str(_dict["content"]) - m_c = FunctionMessage(**_dict) - msgs.append(m_c) - else: - msgs.append(m) - return [SystemMessage(content=system_message)] + msgs - - if tools: - llm_with_tools = llm.bind(functions=tools) - else: - llm_with_tools = llm - agent = _get_messages | llm_with_tools - tool_executor = ToolExecutor(tools) - - # Define the function that determines whether to continue or not - def should_continue(messages): - last_message = messages[-1] - # If there is no function call, then we finish - if "function_call" not in last_message.additional_kwargs: - return "end" - # Otherwise if there is, we continue - else: - 
return "continue" - - # Define the function to execute tools - async def call_tool(messages): - # Based on the continue condition - # we know the last message involves a function call - last_message = messages[-1] - # We construct an ToolInvocation from the function_call - action = ToolInvocation( - tool=last_message.additional_kwargs["function_call"]["name"], - tool_input=json.loads( - last_message.additional_kwargs["function_call"]["arguments"] - ), - ) - # We call the tool_executor and get back a response - response = await tool_executor.ainvoke(action) - # We use the response to create a FunctionMessage - function_message = LiberalFunctionMessage(content=response, name=action.tool) - # We return a list, because this will get added to the existing list - return function_message - - workflow = MessageGraph() - - # Define the two nodes we will cycle between - workflow.add_node("agent", agent) - workflow.add_node("action", call_tool) - - # Set the entrypoint as `agent` - # This means that this node is the first one called - workflow.set_entry_point("agent") - - # We now add a conditional edge - workflow.add_conditional_edges( - # First, we define the start node. We use `agent`. - # This means these are the edges taken after the `agent` node is called. - "agent", - # Next, we pass in the function that will determine which node is called next. - should_continue, - # Finally we pass in a mapping. - # The keys are strings, and the values are other nodes. - # END is a special node marking that the graph should finish. - # What will happen is we will call `should_continue`, and then the output of that - # will be matched against the keys in this mapping. - # Based on which one it matches, that node will then be called. - { - # If `tools`, then we call the tool node. - "continue": "action", - # Otherwise we finish. - "end": END, - }, - ) - - # We now add a normal edge from `tools` to `agent`. - # This means that after `tools` is called, `agent` node is called next. - workflow.add_edge("action", "agent") - - # Finally, we compile it! 
- # This compiles it into a LangChain Runnable, - # meaning you can use it as you would any other runnable - return workflow.compile( - checkpointer=checkpoint, - interrupt_before=["action"] if interrupt_before_action else None, - ) diff --git a/backend/app/agent_types/openai_agent.py b/backend/app/agent_types/tools_agent.py similarity index 79% rename from backend/app/agent_types/openai_agent.py rename to backend/app/agent_types/tools_agent.py index 4fb3d80a..0a061af1 100644 --- a/backend/app/agent_types/openai_agent.py +++ b/backend/app/agent_types/tools_agent.py @@ -1,9 +1,14 @@ -import json +from typing import cast from langchain.tools import BaseTool -from langchain.tools.render import format_tool_to_openai_tool from langchain_core.language_models.base import LanguageModelLike -from langchain_core.messages import SystemMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langgraph.checkpoint import BaseCheckpointSaver from langgraph.graph import END from langgraph.graph.message import MessageGraph @@ -12,7 +17,7 @@ from app.message_types import LiberalToolMessage -def get_openai_agent_executor( +def get_tools_agent_executor( tools: list[BaseTool], llm: LanguageModelLike, system_message: str, @@ -27,13 +32,16 @@ async def _get_messages(messages): _dict["content"] = str(_dict["content"]) m_c = ToolMessage(**_dict) msgs.append(m_c) + elif isinstance(m, FunctionMessage): + # anthropic doesn't like function messages + msgs.append(HumanMessage(content=str(m.content))) else: msgs.append(m) return [SystemMessage(content=system_message)] + msgs if tools: - llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools]) + llm_with_tools = llm.bind_tools(tools) else: llm_with_tools = llm agent = _get_messages | llm_with_tools @@ -43,7 +51,7 @@ async def _get_messages(messages): def should_continue(messages): last_message = messages[-1] # If there is no function call, then we finish - if "tool_calls" not in last_message.additional_kwargs: + if not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: @@ -54,16 +62,13 @@ async def call_tool(messages): actions: list[ToolInvocation] = [] # Based on the continue condition # we know the last message involves a function call - last_message = messages[-1] - for tool_call in last_message.additional_kwargs["tool_calls"]: - function = tool_call["function"] - function_name = function["name"] - _tool_input = json.loads(function["arguments"] or "{}") - # We construct an ToolInvocation from the function_call + last_message = cast(AIMessage, messages[-1]) + for tool_call in last_message.tool_calls: + # We construct a ToolInvocation from the function_call actions.append( ToolInvocation( - tool=function_name, - tool_input=_tool_input, + tool=tool_call["name"], + tool_input=tool_call["args"], ) ) # We call the tool_executor and get back a response @@ -72,12 +77,10 @@ async def call_tool(messages): tool_messages = [ LiberalToolMessage( tool_call_id=tool_call["id"], + name=tool_call["name"], content=response, - additional_kwargs={"name": tool_call["function"]["name"]}, - ) - for tool_call, response in zip( - last_message.additional_kwargs["tool_calls"], responses ) + for tool_call, response in zip(last_message.tool_calls, responses) ] return tool_messages diff --git a/backend/app/checkpoint.py b/backend/app/checkpoint.py index 9e4681b4..7f79e85b 100644 --- a/backend/app/checkpoint.py +++ b/backend/app/checkpoint.py @@ -2,6 +2,7 @@ 
from datetime import datetime from typing import AsyncIterator, Optional +from langchain_core.messages import BaseMessage from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig from langgraph.checkpoint import BaseCheckpointSaver from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple @@ -9,6 +10,14 @@ from app.lifespan import get_pg_pool +def loads(value: bytes) -> Checkpoint: + loaded: Checkpoint = pickle.loads(value) + for key, value in loaded["channel_values"].items(): + if isinstance(value, list) and all(isinstance(v, BaseMessage) for v in value): + loaded["channel_values"][key] = [v.__class__(**v.__dict__) for v in value] + return loaded + + class PostgresCheckpoint(BaseCheckpointSaver): class Config: arbitrary_types_allowed = True @@ -47,7 +56,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: "thread_ts": value[1], } }, - pickle.loads(value[0]), + loads(value[0]), { "configurable": { "thread_id": thread_id, @@ -70,7 +79,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: ): return CheckpointTuple( config, - pickle.loads(value[0]), + loads(value[0]), { "configurable": { "thread_id": thread_id, @@ -92,7 +101,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: "thread_ts": value[1], } }, - pickle.loads(value[0]), + loads(value[0]), { "configurable": { "thread_id": thread_id, diff --git a/backend/app/llms.py b/backend/app/llms.py index 02a469f5..b1dcce85 100644 --- a/backend/app/llms.py +++ b/backend/app/llms.py @@ -5,10 +5,11 @@ import boto3 import httpx -from langchain_community.chat_models import BedrockChat, ChatAnthropic, ChatFireworks +from langchain_community.chat_models import BedrockChat, ChatFireworks from langchain_community.chat_models.ollama import ChatOllama from langchain_google_vertexai import ChatVertexAI from langchain_openai import AzureChatOpenAI, ChatOpenAI +from langchain_anthropic import ChatAnthropic logger = logging.getLogger(__name__) @@ -67,7 +68,11 @@ def get_anthropic_llm(bedrock: bool = False): ) model = BedrockChat(model_id="anthropic.claude-v2", client=client) else: - model = ChatAnthropic(temperature=0, max_tokens_to_sample=2000) + model = ChatAnthropic( + model_name="claude-3-haiku-20240307", + max_tokens_to_sample=2000, + temperature=0, + ) return model diff --git a/backend/poetry.lock b/backend/poetry.lock index 3fc44196..5288f3a8 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -112,13 +112,13 @@ frozenlist = ">=1.1.0" [[package]] name = "anthropic" -version = "0.13.0" +version = "0.25.2" description = "The official Python library for the anthropic API" optional = false python-versions = ">=3.7" files = [ - {file = "anthropic-0.13.0-py3-none-any.whl", hash = "sha256:2d4b6a69bf5b31a596669d68820f40f5ed9a9a3333ddaa727166a11ed29275e8"}, - {file = "anthropic-0.13.0.tar.gz", hash = "sha256:b935d2fee12f7dbfcc80398b3da5f20103ece42aecb97d8ce24459e3c4f8ec8a"}, + {file = "anthropic-0.25.2-py3-none-any.whl", hash = "sha256:f854030b11052f7cbb5257be6134c8a8f25aa538f73013260e12238ff94234a3"}, + {file = "anthropic-0.25.2.tar.gz", hash = "sha256:cdf30ac234e3c0b305307399a6bb5dba45881adcb188d88fdf59802f90f15d6d"}, ] [package.dependencies] @@ -131,6 +131,7 @@ tokenizers = ">=0.13.0" typing-extensions = ">=4.7,<5" [package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] vertex = ["google-auth (>=2,<3)"] [[package]] @@ -321,8 +322,8 @@ files = [ jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, ] [package.extras] @@ -733,6 +734,17 @@ files = [ marshmallow = ">=3.18.0,<4.0.0" typing-inspect = ">=0.4.0,<1" +[[package]] +name = "defusedxml" +version = "0.7.1" +description = "XML bomb protection for Python stdlib modules" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, + {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, +] + [[package]] name = "deprecated" version = "1.2.14" @@ -771,6 +783,17 @@ files = [ {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"}, ] +[[package]] +name = "docstring-parser" +version = "0.16" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, +] + [[package]] name = "duckduckgo-search" version = "5.3.0" @@ -1017,12 +1040,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || 
>4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1057,40 +1080,47 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] [[package]] name = "google-cloud-aiplatform" -version = "1.40.0" +version = "1.47.0" description = "Vertex AI API client library" optional = false python-versions = ">=3.8" files = [ - {file = "google-cloud-aiplatform-1.40.0.tar.gz", hash = "sha256:1ee9aff2fa27c6852558a2abeaf0ffe0537bff90c5dc9f0e967762ac17291001"}, - {file = "google_cloud_aiplatform-1.40.0-py2.py3-none-any.whl", hash = "sha256:9c67a2664e138387ea82d70dec4b54e081b7de6e1089ed23fdaf66900d00320a"}, + {file = "google-cloud-aiplatform-1.47.0.tar.gz", hash = "sha256:1c4537db09b83957bf0623fd2afb37e339f89a3afcda3efce9dce79b16ab59c7"}, + {file = "google_cloud_aiplatform-1.47.0-py2.py3-none-any.whl", hash = "sha256:454ef0c44ecaeadcffe58f565acfce49e53895fd51bb20da8af0d48202a4cb21"}, ] [package.dependencies] -google-api-core = {version = ">=1.32.0,<2.0.dev0 || >=2.8.dev0,<3.0.0dev", extras = ["grpc"]} -google-cloud-bigquery = ">=1.15.0,<4.0.0dev" +docstring-parser = "<1" +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.8.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<3.0.0dev" +google-cloud-bigquery = ">=1.15.0,<3.20.0 || >3.20.0,<4.0.0dev" google-cloud-resource-manager = ">=1.3.3,<3.0.0dev" google-cloud-storage = ">=1.32.0,<3.0.0dev" packaging = ">=14.3" proto-plus = ">=1.22.0,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" +pydantic = "<3" shapely = "<3.0.0dev" [package.extras] autologging = ["mlflow (>=1.27.0,<=2.1.1)"] cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] -datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)"] +datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)"] endpoint = ["requests (>=2.28.1)"] -full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "starlette (>=0.17.1)", "tensorflow (>=2.3.0,<2.15.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)"] +full = ["cloudpickle (<3.0)", "cloudpickle (>=2.2.1,<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pydantic (<3)", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "starlette (>=0.17.1)", "tensorflow (>=2.3.0,<2.15.0)", "tensorflow 
(>=2.3.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)"] +langchain = ["langchain (>=0.1.13,<0.2)", "langchain-core (<0.2)", "langchain-google-vertexai (<0.2)"] lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"] metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"] pipelines = ["pyyaml (==5.3.1)"] prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<0.103.1)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"] preview = ["cloudpickle (<3.0)", "google-cloud-logging (<4.0)"] private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"] -ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)"] +rapid-evaluation = ["nest-asyncio (>=1.0.0,<1.6.0)", "pandas (>=1.0.0,<2.2.0)"] +ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)"] +ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "ray[train] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "scikit-learn", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"] +reasoningengine = ["cloudpickle (>=2.2.1,<3.0)", "pydantic (<3)"] tensorboard = ["tensorflow (>=2.3.0,<2.15.0)"] -testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<2.15.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<=2.12.0)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost", "xgboost-ray"] +testing = ["bigframes", "cloudpickle (<3.0)", "cloudpickle (>=2.2.1,<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pydantic (<3)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5.dev0 
|| >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<2.15.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"] vizier = ["google-vizier (>=0.1.6)"] xai = ["tensorflow (>=2.3.0,<3.0.0dev)"] @@ -1699,6 +1729,22 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +[[package]] +name = "langchain-anthropic" +version = "0.1.8" +description = "An integration package connecting AnthropicMessages and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_anthropic-0.1.8-py3-none-any.whl", hash = "sha256:634eda00a1b2f4dc9bc59f35b6593483dd845c898af7ae491f91fb9ed871dc2b"}, + {file = "langchain_anthropic-0.1.8.tar.gz", hash = "sha256:e3e03dcc25338797a867705b296faba910243559c37a517992586d866b363bb3"}, +] + +[package.dependencies] +anthropic = ">=0.23.0,<1" +defusedxml = ">=0.7.1,<0.8.0" +langchain-core = ">=0.1.42,<0.2.0" + [[package]] name = "langchain-community" version = "0.0.29" @@ -1727,13 +1773,13 @@ extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15. [[package]] name = "langchain-core" -version = "0.1.39" +version = "0.1.42" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.1.39-py3-none-any.whl", hash = "sha256:26b024ef49c5a712611941651bff66fb9a2fd7bc82bd815934c94f0ecf9b6f03"}, - {file = "langchain_core-0.1.39.tar.gz", hash = "sha256:a34bd517dcd9b7e80adf131ee47554736f9532e1bba17593cd0a316a38ec2caf"}, + {file = "langchain_core-0.1.42-py3-none-any.whl", hash = "sha256:c5653ffa08a44f740295c157a24c0def4a753333f6a2c41f76bf431cd00be8b5"}, + {file = "langchain_core-0.1.42.tar.gz", hash = "sha256:40751bf60ea5d8e2b2efe65290db434717ee3834870c002e40e2811f09d814e6"}, ] [package.dependencies] @@ -1749,35 +1795,38 @@ extended-testing = ["jinja2 (>=3,<4)"] [[package]] name = "langchain-google-vertexai" -version = "0.0.3" -description = "An integration package connecting GoogleVertexAI and LangChain" +version = "1.0.1" +description = "An integration package connecting Google VertexAI and LangChain" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_google_vertexai-0.0.3-py3-none-any.whl", hash = "sha256:6090d9f925e579b553a3f3fe30c715d45f5201398c863c71a3ed783ced9f6943"}, - {file = "langchain_google_vertexai-0.0.3.tar.gz", hash = "sha256:74ba72274057e1d384a867754513471e8361dbce438cc604df5514807d0ae9d6"}, + {file = "langchain_google_vertexai-1.0.1-py3-none-any.whl", hash = "sha256:29dc243098a6a5a6972578bc5543c281b871c772a9969abb905d31a4ea39d019"}, + {file = "langchain_google_vertexai-1.0.1.tar.gz", hash = "sha256:a3eb99f1001181f5fa6ccb95c28cd8a3202379775cc7366f5c64e5421a537482"}, ] [package.dependencies] -google-cloud-aiplatform = ">=1.39.0,<2.0.0" +google-cloud-aiplatform = ">=1.47.0,<2.0.0" google-cloud-storage = ">=2.14.0,<3.0.0" -langchain-core = ">=0.1.7,<0.2" +langchain-core = ">=0.1.42,<0.2.0" types-protobuf = ">=4.24.0.4,<5.0.0.0" types-requests = ">=2.31.0,<3.0.0" +[package.extras] 
+anthropic = ["anthropic[vertexai] (>=0.23.0,<1)"] + [[package]] name = "langchain-openai" -version = "0.1.1" +version = "0.1.3" description = "An integration package connecting OpenAI and LangChain" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_openai-0.1.1-py3-none-any.whl", hash = "sha256:5cf4df5d2550af673337eafedaeec014ba52f9a25aeb8451206ca254bed01e5c"}, - {file = "langchain_openai-0.1.1.tar.gz", hash = "sha256:d10e9a9fc4c8ea99ca98f23808ce44c7dcdd65354ac07ad10afe874ecf3401ca"}, + {file = "langchain_openai-0.1.3-py3-none-any.whl", hash = "sha256:fa1f27815649291447e5370cb08e2f5a84e5c7c6121d0c055a6e296bd16d1e47"}, + {file = "langchain_openai-0.1.3.tar.gz", hash = "sha256:7f6e377d6bf88d6c2b1969fe5eecc1326271757512739e2f17c855cd7af53345"}, ] [package.dependencies] -langchain-core = ">=0.1.33,<0.2.0" +langchain-core = ">=0.1.42,<0.2.0" openai = ">=1.10.0,<2.0.0" tiktoken = ">=0.5.2,<1" @@ -1830,17 +1879,17 @@ six = "*" [[package]] name = "langgraph" -version = "0.0.31" +version = "0.0.37" description = "langgraph" optional = false python-versions = "<4.0,>=3.9.0" files = [ - {file = "langgraph-0.0.31-py3-none-any.whl", hash = "sha256:4dfa5d424b0330bfcf2e077acbad545a924672af8b669225f2480b2f84afa710"}, - {file = "langgraph-0.0.31.tar.gz", hash = "sha256:45b95c19dc4e66c4ef67678538070f77199d5f4673723ca66b84299e74693198"}, + {file = "langgraph-0.0.37-py3-none-any.whl", hash = "sha256:2a3c366353ee3380ce5c88f6dceefb662652ec4bcda43ced7c5efc50c635617f"}, + {file = "langgraph-0.0.37.tar.gz", hash = "sha256:fbed1906cbee764f644bb9c71352fa7f029bf3436caf837ca26de8a5ca21ce27"}, ] [package.dependencies] -langchain-core = ">=0.1.38,<0.2.0" +langchain-core = ">=0.1.42,<0.2.0" [[package]] name = "langserve" @@ -4168,4 +4217,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9.0,<3.12" -content-hash = "027126e2f06254070ba7fbcc6244eb5eb08eaf7caa4130f773abadb24ec584db" +content-hash = "2d8bfc1948b224006ffb1bf4dd36ab9aff037088d24648514a069d3f78a356dd" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 3f763819..d6b9a1bd 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,10 +20,10 @@ orjson = "^3.9.10" python-multipart = "^0.0.6" tiktoken = "^0.5.1" langchain = ">=0.0.338" -langgraph = "^0.0.31" +langgraph = "^0.0.37" pydantic = "<2.0" python-magic = "^0.4.27" -langchain-openai = "^0.1.1" +langchain-openai = "^0.1.3" beautifulsoup4 = "^4.12.3" boto3 = "^1.34.28" duckduckgo-search = "^5.3.0" @@ -31,19 +31,19 @@ arxiv = "^2.1.0" kay = "^0.1.2" xmltodict = "^0.13.0" wikipedia = "^1.4.0" -langchain-google-vertexai = "^0.0.3" +langchain-google-vertexai = "^1.0.1" setuptools = "^69.0.3" pdfminer-six = "^20231228" langchain-robocorp = "^0.0.5" fireworks-ai = "^0.11.2" -anthropic = "^0.13.0" httpx = { version = "0.25.2", extras = ["socks"] } unstructured = {extras = ["doc", "docx"], version = "^0.12.5"} pgvector = "^0.2.5" psycopg2-binary = "^2.9.9" asyncpg = "^0.29.0" -langchain-core = "^0.1.39" +langchain-core = "^0.1.42" pyjwt = {extras = ["crypto"], version = "^2.8.0"} +langchain-anthropic = "^0.1.8" [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2" diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index d4a53f3b..84248836 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -48,7 +48,6 @@ function App(props: { edit?: boolean }) { ? 
[ { content: message.message, - additional_kwargs: {}, type: "human", example: false, id: `human-${Math.random()}`, diff --git a/frontend/src/api/threads.ts b/frontend/src/api/threads.ts index 4e4e29c3..ffa1b879 100644 --- a/frontend/src/api/threads.ts +++ b/frontend/src/api/threads.ts @@ -1,4 +1,4 @@ -import { Chat } from "../hooks/useChatList.ts"; +import { Chat } from "../types"; export async function getThread(threadId: string): Promise { try { diff --git a/frontend/src/components/Chat.tsx b/frontend/src/components/Chat.tsx index b5a29ebf..796d99ed 100644 --- a/frontend/src/components/Chat.tsx +++ b/frontend/src/components/Chat.tsx @@ -2,7 +2,7 @@ import { useEffect, useRef } from "react"; import { StreamStateProps } from "../hooks/useStreamState"; import { useChatMessages } from "../hooks/useChatMessages"; import TypingBox from "./TypingBox"; -import { Message } from "./Message"; +import { MessageViewer } from "./Message"; import { ArrowDownCircleIcon } from "@heroicons/react/24/outline"; import { MessageWithFiles } from "../utils/formTypes.ts"; import { useParams } from "react-router-dom"; @@ -51,7 +51,7 @@ export function Chat(props: ChatProps) { return (
{messages?.map((msg, i) => ( - TYPES[id as keyof typeof TYPES]) ?? []; return (
-
+
-
+
@@ -477,6 +488,21 @@ const ORDER = [ "agent_type", ]; +function assignDefaults( + config: ConfigInterface["config"] | undefined | null, + configDefaults: Schemas["configDefaults"], +) { + return config + ? { + ...config, + configurable: { + ...configDefaults?.configurable, + ...config.configurable, + }, + } + : configDefaults; +} + export function Config(props: { className?: string; configSchema: Schemas["configSchema"]; @@ -487,7 +513,7 @@ export function Config(props: { edit?: boolean; }) { const [values, setValues] = useState( - props.config?.config ?? props.configDefaults, + assignDefaults(props.config?.config, props.configDefaults), ); const [selectedTools, setSelectedTools] = useState([]); const typeKey = "type"; @@ -526,7 +552,7 @@ export function Config(props: { }; useEffect(() => { - setValues(props.config?.config ?? props.configDefaults); + setValues(assignDefaults(props.config?.config, props.configDefaults)); }, [props.config, props.configDefaults]); useEffect(() => { if (dropzone.acceptedFiles.length > 0) { diff --git a/frontend/src/components/Document.tsx b/frontend/src/components/Document.tsx index 4960102b..c2b194c3 100644 --- a/frontend/src/components/Document.tsx +++ b/frontend/src/components/Document.tsx @@ -1,13 +1,25 @@ import { useMemo, useState } from "react"; -import { cn } from "../utils/cn"; import { ChevronDownIcon, ChevronRightIcon } from "@heroicons/react/24/outline"; +import { cn } from "../utils/cn"; +import { MessageDocument } from "../types"; +import { StringViewer } from "./String"; + +function isValidHttpUrl(str: string) { + let url; + + try { + url = new URL(str); + } catch (_) { + return false; + } -export interface PageDocument { - page_content: string; - metadata: Record; + return url.protocol === "http:" || url.protocol === "https:"; } -function PageDocument(props: { document: PageDocument; className?: string }) { +function DocumentViewer(props: { + document: MessageDocument; + className?: string; +}) { const [open, setOpen] = useState(false); const metadata = useMemo(() => { @@ -46,28 +58,28 @@ function PageDocument(props: { document: PageDocument; className?: string }) { )} onClick={() => setOpen(true)} > - - - {props.document.page_content.trim().replace(/\n/g, " ")} - + + ); } return ( - - - {props.document.page_content} - + {metadata.map(({ key, value }, idx) => { @@ -77,22 +89,28 @@ function PageDocument(props: { document: PageDocument; className?: string }) { key={idx} > {key} - {value} + {isValidHttpUrl(value) ? ( + + {value} + + ) : ( + {value} + )} ); })} - +
); } -export function DocumentList(props: { documents: PageDocument[] }) { +export function DocumentList(props: { documents: MessageDocument[] }) { return (
      {props.documents.map((document, idx) => (
-        <PageDocument document={document} key={idx} />
+        <DocumentViewer document={document} key={idx} />
      ))}
diff --git a/frontend/src/components/Message.tsx b/frontend/src/components/Message.tsx index d691b449..7e66fec2 100644 --- a/frontend/src/components/Message.tsx +++ b/frontend/src/components/Message.tsx @@ -1,41 +1,76 @@ import { memo, useState } from "react"; -import { Message as MessageType } from "../hooks/useChatList"; +import { MessageDocument, Message as MessageType, ToolCall } from "../types"; import { str } from "../utils/str"; import { cn } from "../utils/cn"; -import { marked } from "marked"; -import DOMPurify from "dompurify"; import { ChevronDownIcon } from "@heroicons/react/24/outline"; import { LangSmithActions } from "./LangSmithActions"; import { DocumentList } from "./Document"; +import { omit } from "lodash"; +import { StringViewer } from "./String"; -function tryJsonParse(value: string) { - try { - return JSON.parse(value); - } catch (e) { - return {}; - } +function ToolRequest( + props: ToolCall & { + open?: boolean; + setOpen?: (open: boolean) => void; + }, +) { + return ( + <> + + Use + + {props.name && ( + + {props.name} + + )} + {props.args && ( +
+
+ + + {Object.entries(props.args).map(([key, value], i) => ( + + + + + ))} + +
+
{key}
+
+ {str(value)} +
+
+
+ )} + + ); } -function Function(props: { - call: boolean; +function ToolResponse(props: { name?: string; - args?: string; open?: boolean; setOpen?: (open: boolean) => void; }) { return ( <> - {props.call && ( - - Use - - )} {props.name && ( {props.name} )} - {!props.call && props.setOpen && ( + {props.setOpen && ( )} - {props.args && ( -
-
- - - {Object.entries(tryJsonParse(props.args)).map( - ([key, value], i) => ( - - - - - ), - )} - -
-
{key}
-
- {str(value)} -
-
-
- )} ); } -export const Message = memo(function Message( +function isDocumentContent( + content: MessageType["content"], +): content is MessageDocument[] { + return ( + Array.isArray(content) && + content.every((d) => typeof d === "object" && !!d && !!d.page_content) + ); +} + +export function MessageContent(props: { content: MessageType["content"] }) { + if (typeof props.content === "string") { + return ; + } else if (isDocumentContent(props.content)) { + return ; + } else if ( + Array.isArray(props.content) && + props.content.every( + (it) => typeof it === "object" && !!it && typeof it.content === "string", + ) + ) { + return ( + ({ + page_content: it.content, + metadata: omit(it, "content"), + }))} + /> + ); + } else { + let content = props.content; + if (Array.isArray(content)) { + content = content.filter((it) => + typeof it === "object" && !!it && "type" in it + ? it.type !== "tool_use" + : true, + ); + } + if (Array.isArray(content) ? content.length === 0 : !content) { + return null; + } + return
{str(content)}
; + } +} + +export const MessageViewer = memo(function ( props: MessageType & { runId?: string }, ) { const [open, setOpen] = useState(false); const contentIsDocuments = ["function", "tool"].includes(props.type) && - Array.isArray(props.content) && - props.content.every((d) => !!d.page_content); + isDocumentContent(props.content); + const showContent = + ["function", "tool"].includes(props.type) && !contentIsDocuments + ? open + : true; return (
@@ -109,51 +159,16 @@ export const Message = memo(function Message(
{["function", "tool"].includes(props.type) && ( - )} - {props.additional_kwargs?.function_call && ( - - )} - {props.additional_kwargs?.tool_calls - ?.filter((call) => call.function) - ?.map((call) => ( - - ))} - {( - ["function", "tool"].includes(props.type) && !contentIsDocuments - ? open - : true - ) ? ( - typeof props.content === "string" ? ( -
- ) : contentIsDocuments ? ( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - - ) : ( -
{str(props.content)}
-          )
-        ) : (
-          false
-        )}
+        {props.tool_calls?.map((call) => (
+          <ToolRequest key={call.id} {...call} />
+        ))}
+        {showContent && <MessageContent content={props.content} />}
{props.runId && ( diff --git a/frontend/src/components/String.tsx b/frontend/src/components/String.tsx new file mode 100644 index 00000000..b547e834 --- /dev/null +++ b/frontend/src/components/String.tsx @@ -0,0 +1,19 @@ +import { MarkedOptions, marked } from "marked"; +import DOMPurify from "dompurify"; +import { cn } from "../utils/cn"; + +const OPTIONS: MarkedOptions = { + gfm: true, + breaks: true, +}; + +export function StringViewer(props: { value: string; className?: string }) { + return ( +
+ ); +} diff --git a/frontend/src/components/TypingBox.tsx b/frontend/src/components/TypingBox.tsx index c369c9d6..df463e05 100644 --- a/frontend/src/components/TypingBox.tsx +++ b/frontend/src/components/TypingBox.tsx @@ -12,7 +12,7 @@ import { useDropzone } from "react-dropzone"; import { MessageWithFiles } from "../utils/formTypes.ts"; import { DROPZONE_CONFIG, TYPE_NAME } from "../constants.ts"; import { Config } from "../hooks/useConfigList.ts"; -import { Chat } from "../hooks/useChatList.ts"; +import { Chat } from "../types"; function getFileTypeIcon(fileType: string) { switch (fileType) { diff --git a/frontend/src/hooks/useChatList.ts b/frontend/src/hooks/useChatList.ts index 2e3d4123..b5a51bb6 100644 --- a/frontend/src/hooks/useChatList.ts +++ b/frontend/src/hooks/useChatList.ts @@ -1,37 +1,6 @@ import { useCallback, useEffect, useReducer } from "react"; import orderBy from "lodash/orderBy"; - -export interface Message { - id: string; - type: string; - content: - | string - | { page_content: string; metadata: Record }[] - | object; - name?: string; - additional_kwargs?: { - name?: string; - function_call?: { - name?: string; - arguments?: string; - }; - tool_calls?: { - id: string; - function?: { - name?: string; - arguments?: string; - }; - }[]; - }; - example: boolean; -} - -export interface Chat { - assistant_id: string; - thread_id: string; - name: string; - updated_at: string; -} +import { Chat } from "../types"; export interface ChatListProps { chats: Chat[] | null; diff --git a/frontend/src/hooks/useChatMessages.ts b/frontend/src/hooks/useChatMessages.ts index 28830af3..40db07ed 100644 --- a/frontend/src/hooks/useChatMessages.ts +++ b/frontend/src/hooks/useChatMessages.ts @@ -1,5 +1,5 @@ import { useEffect, useMemo, useRef, useState } from "react"; -import { Message } from "./useChatList"; +import { Message } from "../types"; import { StreamState, mergeMessagesById } from "./useStreamState"; async function getState(threadId: string) { diff --git a/frontend/src/hooks/useStreamState.tsx b/frontend/src/hooks/useStreamState.tsx index 960f2c26..f0a4baee 100644 --- a/frontend/src/hooks/useStreamState.tsx +++ b/frontend/src/hooks/useStreamState.tsx @@ -1,6 +1,6 @@ import { useCallback, useState } from "react"; import { fetchEventSource } from "@microsoft/fetch-event-source"; -import { Message } from "./useChatList"; +import { Message } from "../types"; export interface StreamState { status: "inflight" | "error" | "done"; diff --git a/frontend/src/types.ts b/frontend/src/types.ts new file mode 100644 index 00000000..b1bd0de2 --- /dev/null +++ b/frontend/src/types.ts @@ -0,0 +1,26 @@ +export interface ToolCall { + id: string; + name: string; + args: Record; +} + +export interface MessageDocument { + page_content: string; + metadata: Record; +} + +export interface Message { + id: string; + type: string; + content: string | MessageDocument[] | object; + name?: string; + tool_calls?: ToolCall[]; + example: boolean; +} + +export interface Chat { + assistant_id: string; + thread_id: string; + name: string; + updated_at: string; +}