Skip to content

Commit

Permalink
OpenAI integration in python (#10)
Browse files Browse the repository at this point in the history
* OpenAI sdk integration

* Fix variable and add more error context

* Add comment

* Remove fulfilled todo

* Default to using input data as document

* Spacing

* Return type and clearer spacing for error

* Update python/starpoint/db.py

Co-authored-by: Scott Wey <scottwey@users.noreply.github.com>

* Fix requirements

* Add tests for openai init

* Fix tests from rebase

* Add transpose and insert functionality

* Add path checking for api key from file

* Add test and fix bug with client.transpose_and_insert

* Add tests for transpose and insert

* Fix bug with optional without default

* Fix requests mock name

* Remove unneeded check

* Add basic tests for openai integration

* More tests for openai integration

* Better testability and fix bug with listwise processing

* Test for write failure

* Fix missing coverage

* Negative test for infer schema

* Test coverage for infer schema ok response

---------

Co-authored-by: Scott Wey <scottwey@users.noreply.github.com>
  • Loading branch information
zapplecat and scottwey authored Jul 12, 2023
1 parent 11f9852 commit a325837
Show file tree
Hide file tree
Showing 3 changed files with 585 additions and 41 deletions.
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
validators~=0.20
requests~=2.28
openai~=0.27
137 changes: 136 additions & 1 deletion python/starpoint/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from typing import Any, Dict, List, Optional
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
from uuid import UUID

import openai
import requests
import validators

Expand All @@ -24,7 +26,17 @@
# Message for requests that supply both a collection id and a collection name.
MULTI_COLLECTION_VALUE_ERROR = "Please only provide either collection_id or collection_name in your request."

# Messages used when validating the api_key / key-file arguments.
NO_API_KEY_VALUE_ERROR = (
    "Please provide at least one value for either api_key "
    "or filepath where the api key lives."
)
MULTI_API_KEY_VALUE_ERROR = (
    "Please only provide either api_key or filepath "
    "with the api_key in your initialization."
)
NO_API_KEY_FILE_ERROR = (
    "The provided filepath for the API key is not a valid file."
)

# Message describing a request that failed with an SSLError.
SSL_ERROR_MSG = (
    "Request failed due to SSLError. "
    "Error is likely due to invalid API key. "
    "Please check if your API is correct and still valid."
)

# Logged when the embeddings and document metadata lists differ in length.
EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING = (
    "The length of the embeddings and document_metadata provided are different. "
    "There may be a mismatch between embeddings and the expected document metadata length; "
    "this may cause undesired collection insert or update."
)

# Logged when the OpenAI embedding response carries no "data" entry.
NO_EMBEDDING_DATA_FOUND = "No embedding data found in the embedding response from OpenAI."


def _build_header(api_key: UUID, additional_headers: Optional[Dict[str, str]] = None):
Expand Down Expand Up @@ -173,6 +185,30 @@ def insert(
return {}
return response.json()

def transpose_and_insert(
    self,
    embeddings: List[float],
    document_metadatas: List[Dict[Any, Any]],
    collection_id: Optional[UUID] = None,
    collection_name: Optional[str] = None,
) -> Dict[Any, Any]:
    """Pair each embedding with its metadata and insert the resulting documents.

    The two inputs are column-oriented (all embeddings, all metadatas); this
    method zips them into row-oriented documents of the form
    ``{"embedding": ..., "metadata": ...}`` and hands them to ``self.insert``.

    Args:
        embeddings: One entry per document. Any iterable is accepted,
            including lazy ones such as a ``map`` object.
        document_metadatas: One metadata dict per embedding, in the same order.
        collection_id: Forwarded unchanged to ``self.insert``.
        collection_name: Forwarded unchanged to ``self.insert``.

    Returns:
        Whatever ``self.insert`` returns for the built documents.
    """
    # Materialise both inputs so len() works on lazy iterables and each
    # iterable is consumed exactly once.
    embeddings = list(embeddings)
    document_metadatas = list(document_metadatas)

    # A mismatch is only warned about: zip() below silently truncates to the
    # shorter of the two inputs.
    if len(embeddings) != len(document_metadatas):
        LOGGER.warning(EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING)

    document = [
        {
            "embedding": embedding,
            "metadata": document_metadata,
        }
        for embedding, document_metadata in zip(embeddings, document_metadatas)
    ]

    # NOTE(review): the keyword here is ``document`` while the sibling
    # ``update`` method takes ``documents`` -- confirm against ``insert``'s
    # signature.
    # Return the response so callers receive the Dict this method's
    # annotation promises.
    return self.insert(
        document=document,
        collection_id=collection_id,
        collection_name=collection_name,
    )

def update(
self,
documents: List[Dict[Any, Any]],
Expand Down Expand Up @@ -331,6 +367,9 @@ def __init__(
self.writer = Writer(api_key=api_key, host=writer_host)
self.reader = Reader(api_key=api_key, host=reader_host)

# Consider a wrapper around openai once this class gets bloated
self.openai = None

def delete(
self,
documents: List[UUID],
Expand All @@ -355,6 +394,20 @@ def insert(
collection_name=collection_name,
)

def transpose_and_insert(
    self,
    embeddings: List[float],
    document_metadatas: List[Dict[Any, Any]],
    collection_id: Optional[UUID] = None,
    collection_name: Optional[str] = None,
) -> Dict[Any, Any]:
    """Delegate to the writer's ``transpose_and_insert``.

    Every argument is forwarded unchanged, by keyword, to
    ``self.writer.transpose_and_insert`` and its result is returned as-is.
    """
    forwarded = {
        "embeddings": embeddings,
        "document_metadatas": document_metadatas,
        "collection_id": collection_id,
        "collection_name": collection_name,
    }
    response = self.writer.transpose_and_insert(**forwarded)
    return response

def query(
self,
sql: Optional[str] = None,
Expand Down Expand Up @@ -392,3 +445,85 @@ def update(
collection_id=collection_id,
collection_name=collection_name,
)

"""
OpenAI convenience wrappers
"""

def init_openai(
    self,
    openai_key: Optional[str] = None,
    openai_key_filepath: Optional[str] = None,
):
    """Attach the ``openai`` module to this client and configure its API key.

    Exactly one credential source is expected: either the key itself or the
    path of a file holding it.

    Args:
        openai_key: The API key; stored on ``openai.api_key``.
        openai_key_filepath: Path to an existing file; stored on
            ``openai.api_key_path``.

    Raises:
        ValueError: If both arguments are given, if neither is given, or if
            ``openai_key_filepath`` is not an existing file. In each case
            ``self.openai`` is reset to ``None`` before re-raising.
    """
    self.openai = openai
    # TODO: maybe do this for starpoint api_key also

    # NOTE(review): the credential attribute that is *not* chosen is left
    # untouched on the module, so a value from an earlier call may linger --
    # confirm which attribute the OpenAI library gives precedence to.
    try:
        if openai_key and openai_key_filepath:
            raise ValueError(MULTI_API_KEY_VALUE_ERROR)

        if openai_key is not None:
            self.openai.api_key = openai_key
            return

        # From here on the key must come from a file.
        if openai_key_filepath is None:
            raise ValueError(NO_API_KEY_VALUE_ERROR)
        key_file = Path(openai_key_filepath)
        if not key_file.is_file():
            raise ValueError(NO_API_KEY_FILE_ERROR)
        self.openai.api_key_path = openai_key_filepath
    except ValueError as error:
        # Failed initialisation: detach openai from this object again.
        self.openai = None
        raise error

def build_and_insert_embeddings_from_openai(
    self,
    model: str,
    input_data: Union[str, Iterable],
    document_metadatas: Optional[List[Dict]] = None,
    collection_id: Optional[UUID] = None,
    collection_name: Optional[str] = None,
    openai_user: Optional[str] = None,
) -> Dict:
    """Request embeddings from OpenAI and insert them into a collection.

    Args:
        model: Passed to ``openai.Embedding.create`` as ``model``.
        input_data: Passed to ``openai.Embedding.create`` as ``input``. Also
            used to build default metadata when ``document_metadatas`` is
            omitted.
        document_metadatas: Metadata dicts to pair with the embeddings. When
            ``None``, each input is wrapped as ``{"input": <input>}``.
        collection_id: Target collection id, forwarded to the insert.
        collection_name: Target collection name, forwarded to the insert.
        openai_user: Passed to ``openai.Embedding.create`` as ``user``.

    Returns:
        The embedding response from OpenAI. It is returned even when the
        response has no ``"data"`` entry or the insert fails; insert
        failures are logged rather than raised.

    Raises:
        RuntimeError: If ``init_openai`` has not been called successfully.
    """
    if self.openai is None:
        raise RuntimeError(
            "OpenAI instance has not been initialized. Please initialize it using "
            "Client.init_openai()"
        )

    # Validate the collection arguments before spending an OpenAI request.
    _check_collection_identifier_collision(collection_id, collection_name)

    embedding_response = self.openai.Embedding.create(
        model=model, input=input_data, user=openai_user
    )

    embedding_data = embedding_response.get("data")
    if embedding_data is None:
        LOGGER.warning(NO_EMBEDDING_DATA_FOUND)
        return embedding_response

    if document_metadatas is None:
        LOGGER.info(
            "No custom document_metadatas provided. Using input_data for the document_metadatas"
        )
        if isinstance(input_data, str):
            document_metadatas = [{"input": input_data}]
        else:
            document_metadatas = [{"input": data} for data in input_data]

    # Return the embedding response no matter what issues/bugs we might run into in the sdk
    try:
        # Order by the response's "index" field so embeddings line up with
        # the metadata order.
        sorted_embedding_data = sorted(embedding_data, key=lambda x: x["index"])
        # Build a real list: the insert path calls len() on the embeddings,
        # which a lazy ``map`` object does not support. That TypeError was
        # swallowed below, so nothing was ever inserted.
        embeddings = [entry.get("embedding") for entry in sorted_embedding_data]
        self.transpose_and_insert(
            embeddings=embeddings,
            document_metadatas=document_metadatas,
            collection_id=collection_id,
            collection_name=collection_name,
        )
    except Exception as e:
        LOGGER.error(
            "An exception has occurred while trying to load embeddings into the db. "
            f"This is the error:\n{e}"
        )

    return embedding_response
Loading

0 comments on commit a325837

Please sign in to comment.