From a25a7940af785a83236591699927b85a0576cc82 Mon Sep 17 00:00:00 2001
From: "Le D. Hoa"
Date: Sun, 17 Mar 2024 13:49:49 +0700
Subject: [PATCH] Add ChatGPT fine-tuned model option

---
 .env.example         |  3 +++
 backend/app/agent.py | 12 +++++++++++-
 backend/app/llms.py  |  8 +++++++-
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/.env.example b/.env.example
index 9e476bf9..614b8d47 100644
--- a/.env.example
+++ b/.env.example
@@ -9,3 +9,6 @@ AZURE_OPENAI_API_VERSION=placeholder
 CONNERY_RUNNER_URL=https://your-personal-connery-runner-url
 CONNERY_RUNNER_API_KEY=placeholder
 PROXY_URL=your_proxy_url
+
+# (optional) Custom/fine-tune chatGPT model name
+FINE_TUNED_GPT_MODEL=placeholder
diff --git a/backend/app/agent.py b/backend/app/agent.py
index 92f2640a..14f95007 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -1,3 +1,5 @@
+
+import os
 from enum import Enum
 from typing import Any, Mapping, Optional, Sequence, Union
 
@@ -63,7 +65,7 @@ class AgentType(str, Enum):
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
     GEMINI = "GEMINI"
-
+    FINE_TUNED_GPT_MODEL = 'Fine-Tuned GPT'
 
 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
 
@@ -106,6 +108,11 @@ def get_agent_executor(
         return get_google_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
+    elif agent == AgentType.FINE_TUNED_GPT_MODEL:
+        llm = get_openai_llm(fine_tuned_gpt=True)
+        return get_openai_agent_executor(
+            tools, llm, system_message, interrupt_before_action, CHECKPOINTER
+        )
     else:
         raise ValueError("Unexpected agent type")
 
@@ -175,6 +182,7 @@ class LLMType(str, Enum):
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
     GEMINI = "GEMINI"
     MIXTRAL = "Mixtral"
+    FINE_TUNED_GPT = "Fine-Tuned GPT"
 
 
 def get_chatbot(
@@ -195,6 +203,8 @@ def get_chatbot(
         llm = get_google_llm()
     elif llm_type == LLMType.MIXTRAL:
         llm = get_mixtral_fireworks()
+    elif llm_type == LLMType.FINE_TUNED_GPT:
+        llm = get_openai_llm(fine_tuned_gpt=True)
     else:
         raise ValueError("Unexpected llm type")
     return get_chatbot_executor(llm, system_message, CHECKPOINTER)
diff --git a/backend/app/llms.py b/backend/app/llms.py
index bb58acb1..6de312de 100644
--- a/backend/app/llms.py
+++ b/backend/app/llms.py
@@ -13,7 +13,7 @@
 
 
 @lru_cache(maxsize=4)
-def get_openai_llm(gpt_4: bool = False, azure: bool = False):
+def get_openai_llm(gpt_4: bool = False, azure: bool = False, fine_tuned_gpt: bool = False):
     proxy_url = os.getenv("PROXY_URL")
     http_client = None
     if proxy_url:
@@ -31,6 +31,12 @@ def get_openai_llm(gpt_4: bool = False, azure: bool = False):
             temperature=0,
             streaming=True,
         )
+    elif fine_tuned_gpt:
+        llm = ChatOpenAI(
+            model=os.environ["FINE_TUNED_GPT_MODEL"],
+            temperature=0,
+            streaming=True
+        )
     else:
         llm = ChatOpenAI(
             http_client=http_client,
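
Reviewer note, not part of the patch: a minimal sketch of how the new option would be exercised once the patch is applied. It assumes the backend package is importable as "app" (run from backend/), that OPENAI_API_KEY is set, and that FINE_TUNED_GPT_MODEL points at a real fine-tuned model; the model id below is a hypothetical placeholder.

# Usage sketch only; assumes the patch above is applied and OPENAI_API_KEY is set.
# "ft:gpt-3.5-turbo-0125:my-org::abc123" is a hypothetical fine-tuned model id.
import os

os.environ.setdefault("FINE_TUNED_GPT_MODEL", "ft:gpt-3.5-turbo-0125:my-org::abc123")

from app.llms import get_openai_llm

# The new flag routes ChatOpenAI to the model named by FINE_TUNED_GPT_MODEL.
llm = get_openai_llm(fine_tuned_gpt=True)
print(llm.invoke("Say hello").content)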