format
lf-zhao committed May 4, 2024
1 parent 0774354 commit 68ef57d
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions predicators/pretrained_model_interface.py
@@ -256,7 +256,8 @@ def _sample_completions(
 
 
 class OpenAIVLM(VisionLanguageModel):
-    """Interface for OpenAI's VLMs, including GPT-4 Turbo (and preview versions)."""
+    """Interface for OpenAI's VLMs, including GPT-4 Turbo (and preview
+    versions)."""
 
     def __init__(self, model_name: str = "gpt-4-turbo", detail: str = "auto"):
         """Initialize with a specific model name."""
@@ -271,11 +272,12 @@ def set_openai_key(self, key: Optional[str] = None):
             key = os.environ["OPENAI_API_KEY"]
         openai.api_key = key
 
-    def prepare_vision_messages(
-        self, images: List[PIL.Image.Image],
-        prefix: Optional[str] = None, suffix: Optional[str] = None, image_size: Optional[int] = 512,
-        detail: str = "auto"
-    ):
+    def prepare_vision_messages(self,
+                                images: List[PIL.Image.Image],
+                                prefix: Optional[str] = None,
+                                suffix: Optional[str] = None,
+                                image_size: Optional[int] = 512,
+                                detail: str = "auto"):
         """Prepare text and image messages for the OpenAI API."""
         content = []
 
@@ -291,7 +293,8 @@ def prepare_vision_messages(
             img_resized = img
             if image_size:
                 factor = image_size / max(img.size)
-                img_resized = img.resize((int(img.size[0] * factor), int(img.size[1] * factor)))
+                img_resized = img.resize(
+                    (int(img.size[0] * factor), int(img.size[1] * factor)))
 
             # Convert the image to PNG format and encode it in base64
             buffer = BytesIO()
@@ -312,9 +315,15 @@
 
         return [{"role": "user", "content": content}]
 
-    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-    def call_openai_api(self, messages: list, model: str = "gpt-4", seed: Optional[int] = None, max_tokens: int = 32,
-                        temperature: float = 0.2, verbose: bool = False):
+    @retry(wait=wait_random_exponential(min=1, max=60),
+           stop=stop_after_attempt(6))
+    def call_openai_api(self,
+                        messages: list,
+                        model: str = "gpt-4",
+                        seed: Optional[int] = None,
+                        max_tokens: int = 32,
+                        temperature: float = 0.2,
+                        verbose: bool = False):
         """Make an API call to OpenAI."""
         client = openai.OpenAI()
         completion = client.chat.completions.create(
@@ -334,19 +343,24 @@ def get_id(self) -> str:
         return f"OpenAI-{self.model_name}"
 
     def _sample_completions(
-        self,
-        prompt: str,
-        imgs: Optional[List[PIL.Image.Image]],
-        temperature: float,
-        seed: int,
-        stop_token: Optional[str] = None,
-        num_completions: int = 1,
-        max_tokens=512,
+            self,
+            prompt: str,
+            imgs: Optional[List[PIL.Image.Image]],
+            temperature: float,
+            seed: int,
+            stop_token: Optional[str] = None,
+            num_completions: int = 1,
+            max_tokens=512,
     ) -> List[str]:
         """Query the model and get responses."""
-        messages = self.prepare_vision_messages(prefix=prompt, images=imgs, detail="auto")
+        messages = self.prepare_vision_messages(prefix=prompt,
+                                                images=imgs,
+                                                detail="auto")
         responses = [
-            self.call_openai_api(messages, model=self.model_name, max_tokens=max_tokens, temperature=temperature)
+            self.call_openai_api(messages,
+                                 model=self.model_name,
+                                 max_tokens=max_tokens,
+                                 temperature=temperature)
             for _ in range(num_completions)
         ]
         return responses
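
This commit appears to be a pure formatting pass (yapf-style argument wrapping) over the OpenAIVLM helper; the wrapped lines are otherwise unchanged. For orientation only, the sketch below shows how the methods touched here might fit together. It is not part of the commit: the import path, the "example.png" path, and the prompt text are assumptions, and it presumes OPENAI_API_KEY is set in the environment.

# Minimal usage sketch (illustrative only; not part of this commit).
# Assumes the predicators package is importable, OPENAI_API_KEY is set,
# and "example.png" stands in for a real image path.
from typing import List

import PIL.Image

from predicators.pretrained_model_interface import OpenAIVLM

vlm = OpenAIVLM(model_name="gpt-4-turbo", detail="auto")
vlm.set_openai_key()  # falls back to the OPENAI_API_KEY environment variable

# Build a single user message: optional text prefix plus base64-encoded PNGs.
imgs: List[PIL.Image.Image] = [PIL.Image.open("example.png")]
messages = vlm.prepare_vision_messages(images=imgs,
                                       prefix="Describe this scene.")

# One retried chat-completions call; per the diff, the result is a string.
response = vlm.call_openai_api(messages,
                               model="gpt-4-turbo",
                               max_tokens=64,
                               temperature=0.2)
print(response)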
