diff --git a/backend/app/agent.py b/backend/app/agent.py
index 2658c9b7..f14cb206 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -64,6 +64,7 @@
 class AgentType(str, Enum):
     GPT_35_TURBO = "GPT 3.5 Turbo"
     GPT_4 = "GPT 4 Turbo"
+    GPT_4O = "GPT 4o"
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
@@ -88,7 +89,12 @@ def get_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
     elif agent == AgentType.GPT_4:
-        llm = get_openai_llm(gpt_4=True)
+        llm = get_openai_llm(model="gpt-4-turbo")
+        return get_tools_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
+    elif agent == AgentType.GPT_4O:
+        llm = get_openai_llm(model="gpt-4o")
         return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
@@ -182,6 +188,7 @@ def __init__(
 class LLMType(str, Enum):
     GPT_35_TURBO = "GPT 3.5 Turbo"
     GPT_4 = "GPT 4 Turbo"
+    GPT_4O = "GPT 4o"
     AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
@@ -277,7 +284,9 @@ def __init__(
         if llm_type == LLMType.GPT_35_TURBO:
             llm = get_openai_llm()
         elif llm_type == LLMType.GPT_4:
-            llm = get_openai_llm(gpt_4=True)
+            llm = get_openai_llm(model="gpt-4-turbo")
+        elif llm_type == LLMType.GPT_4O:
+            llm = get_openai_llm(model="gpt-4o")
         elif llm_type == LLMType.AZURE_OPENAI:
             llm = get_openai_llm(azure=True)
         elif llm_type == LLMType.CLAUDE2:
diff --git a/backend/app/llms.py b/backend/app/llms.py
index 2e4f35b9..94d5fc10 100644
--- a/backend/app/llms.py
+++ b/backend/app/llms.py
@@ -15,7 +15,7 @@


 @lru_cache(maxsize=4)
-def get_openai_llm(gpt_4: bool = False, azure: bool = False):
+def get_openai_llm(model: str = "gpt-3.5-turbo", azure: bool = False):
     proxy_url = os.getenv("PROXY_URL")
     http_client = None
     if proxy_url:
@@ -27,7 +27,7 @@ def get_openai_llm(gpt_4: bool = False, azure: bool = False):

     if not azure:
         try:
-            openai_model = "gpt-4-turbo-preview" if gpt_4 else "gpt-3.5-turbo"
+            openai_model = model
             llm = ChatOpenAI(
                 http_client=http_client,
                 model=openai_model,