Skip to content

Commit

Permalink
Migrate pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
lgesuellip committed Oct 24, 2024
1 parent 541ae6f commit cbbda36
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 128 deletions.
6 changes: 3 additions & 3 deletions backend/app/api/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""

name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")
name: Annotated[str, Field(description="The name of the assistant.")]
config: Annotated[dict, Field(description="The assistant config.")]
public: Annotated[bool, Field(default=False, description="Whether the assistant is public.")]


AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
Expand Down
10 changes: 5 additions & 5 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import langsmith.client
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.exceptions import RequestValidationError
from langchain.pydantic_v1 import ValidationError
from pydantic import ValidationError
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langsmith.utils import tracing_is_enabled
Expand Down Expand Up @@ -51,7 +51,7 @@ async def _run_input_and_config(payload: CreateRunPayload, user_id: str):

try:
if payload.input is not None:
agent.get_input_schema(config).validate(payload.input)
agent.get_input_schema(config)(**payload.input)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=payload)

Expand Down Expand Up @@ -84,19 +84,19 @@ async def stream_run(
@router.get("/input_schema")
async def input_schema() -> dict:
"""Return the input schema of the runnable."""
return agent.get_input_schema().schema()
return agent.get_input_schema().model_json_schema()


@router.get("/output_schema")
async def output_schema() -> dict:
"""Return the output schema of the runnable."""
return agent.get_output_schema().schema()
return agent.get_output_schema().model_json_schema()


@router.get("/config_schema")
async def config_schema() -> dict:
"""Return the config schema of the runnable."""
return agent.config_schema().schema()
return agent.config_schema().model_json_schema()


if tracing_is_enabled():
Expand Down
5 changes: 2 additions & 3 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

class ThreadPutRequest(BaseModel):
"""Payload for creating a thread."""

name: str = Field(..., description="The name of the thread.")
assistant_id: str = Field(..., description="The ID of the assistant to use.")
name: Annotated[str, Field(description="The name of the thread.")]
assistant_id: Annotated[str, Field(description="The ID of the assistant to use.")]


class ThreadPostRequest(BaseModel):
Expand Down
29 changes: 18 additions & 11 deletions backend/app/auth/settings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from base64 import b64decode
from enum import Enum
from typing import Optional, Union
from typing import Optional, Union, List

from pydantic import BaseSettings, root_validator, validator
from pydantic import ConfigDict, model_validator, field_validator
from pydantic_settings import BaseSettings


class AuthType(Enum):
Expand All @@ -16,27 +17,32 @@ class JWTSettingsBase(BaseSettings):
iss: str
aud: Union[str, list[str]]

@validator("aud", pre=True, always=True)
def set_aud(cls, v, values) -> Union[str, list[str]]:
return v.split(",") if "," in v else v

class Config:
env_prefix = "jwt_"
@field_validator("aud", mode="before")
@classmethod
def set_aud(cls, v) -> Union[str, List[str]]:
if isinstance(v, str) and "," in v:
return v.split(",")
return v
model_config = ConfigDict(env_prefix="jwt_",)


class JWTSettingsLocal(JWTSettingsBase):
decode_key_b64: str
decode_key: str = None
alg: str

@validator("decode_key", pre=True, always=True)
@field_validator("decode_key", mode="before")
@classmethod
def set_decode_key(cls, v, values):
"""
Key may be a multiline string (e.g. in the case of a public key), so to
be able to set it from env, we set it as a base64 encoded string and
decode it here.
"""
return b64decode(values["decode_key_b64"]).decode("utf-8")
decode_key_b64 = kwargs.get("decode_key_b64")
if decode_key_b64:
return b64decode(decode_key_b64).decode("utf-8")
return v


class JWTSettingsOIDC(JWTSettingsBase):
Expand All @@ -48,7 +54,8 @@ class Settings(BaseSettings):
jwt_local: Optional[JWTSettingsLocal] = None
jwt_oidc: Optional[JWTSettingsOIDC] = None

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_jwt_settings(cls, values):
auth_type = values.get("auth_type")
if auth_type == AuthType.JWT_LOCAL and values.get("jwt_local") is None:
Expand Down
161 changes: 65 additions & 96 deletions backend/app/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from functools import lru_cache
from typing import Optional
from typing import Annotated,Literal, Optional

from langchain.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
from langchain.tools.retriever import create_retriever_tool
from langchain_community.agent_toolkits.connery import ConneryToolkit
from langchain_community.retrievers.kay import KayAiRetriever
Expand All @@ -22,26 +22,24 @@
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_core.tools import Tool
from langchain_robocorp import ActionServerToolkit
from typing_extensions import TypedDict

from app.upload import vstore


class DDGInput(BaseModel):
query: str = Field(description="search query to look up")
query: Annotated[str, Field(description="search query to look up")]


class ArxivInput(BaseModel):
query: str = Field(description="search query to look up")
query: Annotated[str, Field(description="search query to look up")]


class PythonREPLInput(BaseModel):
query: str = Field(description="python command to run")
query: Annotated[str, Field(description="python command to run")]


class DallEInput(BaseModel):
query: str = Field(description="image description to generate image from")
query: Annotated[str, Field(description="image description to generate image from")]


class AvailableTools(str, Enum):
Expand All @@ -66,9 +64,9 @@ class ToolConfig(TypedDict):

class BaseTool(BaseModel):
type: AvailableTools
name: Optional[str]
description: Optional[str]
config: Optional[ToolConfig]
name: Optional[str] = None
description: Optional[str] = None
config: Optional[ToolConfig] = None
multi_use: Optional[bool] = False


Expand All @@ -78,125 +76,107 @@ class ActionServerConfig(ToolConfig):


class ActionServer(BaseTool):
type: AvailableTools = Field(AvailableTools.ACTION_SERVER, const=True)
name: str = Field("Action Server by Sema4.ai", const=True)
description: str = Field(
(
type: Literal[AvailableTools.ACTION_SERVER] = AvailableTools.ACTION_SERVER
name: Literal["Action Server by Sema4.ai"] = "Action Server by Sema4.ai"
description: Literal[(
"Run AI actions with "
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
),
const=True,
)
)] = (
"Run AI actions with "
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
)
config: ActionServerConfig
multi_use: bool = Field(True, const=True)
multi_use: Literal[True] = True


class Connery(BaseTool):
type: AvailableTools = Field(AvailableTools.CONNERY, const=True)
name: str = Field("AI Action Runner by Connery", const=True)
description: str = Field(
(
type: Literal[AvailableTools.CONNERY] = AvailableTools.CONNERY
name: Literal["AI Action Runner by Connery"] = "AI Action Runner by Connery"
description: Literal[(
"Connect OpenGPTs to the real world with "
"[Connery](https://github.com/connery-io/connery)."
),
const=True,
)
)] = (
"Connect OpenGPTs to the real world with "
"[Connery](https://github.com/connery-io/connery)."
)


class DDGSearch(BaseTool):
type: AvailableTools = Field(AvailableTools.DDG_SEARCH, const=True)
name: str = Field("DuckDuckGo Search", const=True)
description: str = Field(
"Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/).",
const=True,
)
type: Literal[AvailableTools.DDG_SEARCH] = AvailableTools.DDG_SEARCH
name: Literal["DuckDuckGo Search"] = "DuckDuckGo Search"
description: Literal["Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."] = "Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."


class Arxiv(BaseTool):
type: AvailableTools = Field(AvailableTools.ARXIV, const=True)
name: str = Field("Arxiv", const=True)
description: str = Field("Searches [Arxiv](https://arxiv.org/).", const=True)
type: Literal[AvailableTools.ARXIV] = AvailableTools.ARXIV
name: Literal["Arxiv"] = "Arxiv"
description: Literal["Searches [Arxiv](https://arxiv.org/)."] = "Searches [Arxiv](https://arxiv.org/)."


class YouSearch(BaseTool):
type: AvailableTools = Field(AvailableTools.YOU_SEARCH, const=True)
name: str = Field("You.com Search", const=True)
description: str = Field(
"Uses [You.com](https://you.com/) search, optimized responses for LLMs.",
const=True,
)
type: Literal[AvailableTools.YOU_SEARCH] = AvailableTools.YOU_SEARCH
name: Literal["You.com Search"] = "You.com Search"
description: Literal["Uses [You.com](https://you.com/) search, optimized responses for LLMs."] = "Uses [You.com](https://you.com/) search, optimized responses for LLMs."


class SecFilings(BaseTool):
type: AvailableTools = Field(AvailableTools.SEC_FILINGS, const=True)
name: str = Field("SEC Filings (Kay.ai)", const=True)
description: str = Field(
"Searches through SEC filings using [Kay.ai](https://www.kay.ai/).", const=True
)
type: Literal[AvailableTools.SEC_FILINGS] = AvailableTools.SEC_FILINGS
name: Literal["SEC Filings (Kay.ai)"] = "SEC Filings (Kay.ai)"
description: Literal["Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."] = "Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."


class PressReleases(BaseTool):
type: AvailableTools = Field(AvailableTools.PRESS_RELEASES, const=True)
name: str = Field("Press Releases (Kay.ai)", const=True)
description: str = Field(
"Searches through press releases using [Kay.ai](https://www.kay.ai/).",
const=True,
)
type: Literal[AvailableTools.PRESS_RELEASES] = AvailableTools.PRESS_RELEASES
name: Literal["Press Releases (Kay.ai)"] = "Press Releases (Kay.ai)"
description: Literal["Searches through press releases using [Kay.ai](https://www.kay.ai/)."] = "Searches through press releases using [Kay.ai](https://www.kay.ai/)."


class PubMed(BaseTool):
type: AvailableTools = Field(AvailableTools.PUBMED, const=True)
name: str = Field("PubMed", const=True)
description: str = Field(
"Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/).", const=True
)
type: Literal[AvailableTools.PUBMED] = AvailableTools.PUBMED
name: Literal["PubMed"] = "PubMed"
description: Literal["Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."] = "Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."


class Wikipedia(BaseTool):
type: AvailableTools = Field(AvailableTools.WIKIPEDIA, const=True)
name: str = Field("Wikipedia", const=True)
description: str = Field(
"Searches [Wikipedia](https://pypi.org/project/wikipedia/).", const=True
)
type: Literal[AvailableTools.WIKIPEDIA] = AvailableTools.WIKIPEDIA
name: Literal["Wikipedia"] = "Wikipedia"
description: Literal["Searches [Wikipedia](https://pypi.org/project/wikipedia/)."] = "Searches [Wikipedia](https://pypi.org/project/wikipedia/)."


class Tavily(BaseTool):
type: AvailableTools = Field(AvailableTools.TAVILY, const=True)
name: str = Field("Search (Tavily)", const=True)
description: str = Field(
(
type: Literal[AvailableTools.TAVILY] = AvailableTools.TAVILY
name: Literal["Search (Tavily)"] = "Search (Tavily)"
description: Literal[(
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"Includes sources in the response."
),
const=True,
)
)] = (
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"Includes sources in the response."
)


class TavilyAnswer(BaseTool):
type: AvailableTools = Field(AvailableTools.TAVILY_ANSWER, const=True)
name: str = Field("Search (short answer, Tavily)", const=True)
description: str = Field(
(
type: Literal[AvailableTools.TAVILY_ANSWER] = AvailableTools.TAVILY_ANSWER
name: Literal["Search (short answer, Tavily)"] = "Search (short answer, Tavily)"
description: Literal[(
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"This returns only the answer, no supporting evidence."
),
const=True,
)
)] = (
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"This returns only the answer, no supporting evidence."
)


class Retrieval(BaseTool):
type: AvailableTools = Field(AvailableTools.RETRIEVAL, const=True)
name: str = Field("Retrieval", const=True)
description: str = Field("Look up information in uploaded files.", const=True)
type: Literal[AvailableTools.RETRIEVAL] = AvailableTools.RETRIEVAL
name: Literal["Retrieval"] = "Retrieval"
description: Literal["Look up information in uploaded files."] = "Look up information in uploaded files."


class DallE(BaseTool):
type: AvailableTools = Field(AvailableTools.DALL_E, const=True)
name: str = Field("Generate Image (Dall-E)", const=True)
description: str = Field(
"Generates images from a text description using OpenAI's DALL-E model.",
const=True,
)
type: Literal[AvailableTools.DALL_E] = AvailableTools.DALL_E
name: Literal["Generate Image (Dall-E)"] = "Generate Image (Dall-E)"
description: Literal["Generates images from a text description using OpenAI's DALL-E model."] = "Generates images from a text description using OpenAI's DALL-E model."


RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
Expand Down Expand Up @@ -286,16 +266,6 @@ def _get_tavily_answer():
return _TavilyAnswer(api_wrapper=tavily_search, name="search_tavily_answer")


def _get_action_server(**kwargs: ActionServerConfig):
toolkit = ActionServerToolkit(
url=kwargs["url"],
api_key=kwargs["api_key"],
additional_headers=kwargs.get("additional_headers", {}),
)
tools = toolkit.get_tools()
return tools


@lru_cache(maxsize=1)
def _get_connery_actions():
connery_service = ConneryService()
Expand All @@ -314,7 +284,6 @@ def _get_dalle_tools():


TOOLS = {
AvailableTools.ACTION_SERVER: _get_action_server,
AvailableTools.CONNERY: _get_connery_actions,
AvailableTools.DDG_SEARCH: _get_duck_duck_go,
AvailableTools.ARXIV: _get_arxiv,
Expand Down
Loading

0 comments on commit cbbda36

Please sign in to comment.