From 04de67acf485b26f027b9b739e91e66fe142fd56 Mon Sep 17 00:00:00 2001
From: "P. Taylor Goetz"
Date: Wed, 15 May 2024 18:12:05 -0400
Subject: [PATCH] Add GPT 4o as a model

---
 backend/app/agent.py | 13 +++++++++++--
 backend/app/llms.py  |  4 ++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/backend/app/agent.py b/backend/app/agent.py
index 2658c9b7..f14cb206 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -64,6 +64,7 @@
 class AgentType(str, Enum):
     GPT_35_TURBO = "GPT 3.5 Turbo"
     GPT_4 = "GPT 4 Turbo"
+    GPT_4O = "GPT 4o"
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
@@ -88,7 +89,12 @@ def get_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.GPT_4:
-        llm = get_openai_llm(gpt_4=True)
+        llm = get_openai_llm(model="gpt-4-turbo")
+        return get_tools_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
+    elif agent == AgentType.GPT_4O:
+        llm = get_openai_llm(model="gpt-4o")
         return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
@@ -182,6 +188,7 @@ def __init__(
 class LLMType(str, Enum):
     GPT_35_TURBO = "GPT 3.5 Turbo"
     GPT_4 = "GPT 4 Turbo"
+    GPT_4O = "GPT 4o"
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
@@ -277,7 +284,9 @@ def __init__(
         if llm_type == LLMType.GPT_35_TURBO:
             llm = get_openai_llm()
         elif llm_type == LLMType.GPT_4:
-            llm = get_openai_llm(gpt_4=True)
+            llm = get_openai_llm(model="gpt-4-turbo")
+        elif llm_type == LLMType.GPT_4O:
+            llm = get_openai_llm(model="gpt-4o")
         elif llm_type == LLMType.AZURE_OPENAI:
             llm = get_openai_llm(azure=True)
         elif llm_type == LLMType.CLAUDE2:
diff --git a/backend/app/llms.py b/backend/app/llms.py
index 2e4f35b9..94d5fc10 100644
--- a/backend/app/llms.py
+++ b/backend/app/llms.py
@@ -15,7 +15,7 @@


 @lru_cache(maxsize=4)
-def get_openai_llm(gpt_4: bool = False, azure: bool = False):
+def get_openai_llm(model: str = "gpt-3.5-turbo", azure: bool = False):
     proxy_url = os.getenv("PROXY_URL")
     http_client = None
     if proxy_url:
@@ -27,7 +27,7 @@ def get_openai_llm(gpt_4: bool = False, azure: bool = False):

     if not azure:
         try:
-            openai_model = "gpt-4-turbo-preview" if gpt_4 else "gpt-3.5-turbo"
+            openai_model = model
             llm = ChatOpenAI(
                 http_client=http_client,
                 model=openai_model,
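
Usage sketch (illustrative, not part of the patch): with the boolean gpt_4 flag replaced by a model parameter, callers now pick the OpenAI model by name. The import path below assumes the backend's app package layout; the model strings are the ones used in the diff.

    # Illustrative sketch only; assumes `from app.llms import get_openai_llm`
    # resolves inside the backend package.
    from app.llms import get_openai_llm

    llm_default = get_openai_llm()                        # defaults to "gpt-3.5-turbo"
    llm_gpt4_turbo = get_openai_llm(model="gpt-4-turbo")  # AgentType.GPT_4 / LLMType.GPT_4
    llm_gpt4o = get_openai_llm(model="gpt-4o")            # AgentType.GPT_4O / LLMType.GPT_4O

    # functools.lru_cache memoizes get_openai_llm by its call arguments, so
    # repeated calls with the same model/azure arguments reuse the same client.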