Skip to content

Commit

Permalink
Function name and docstring changes and add logging for mismatch lists
Browse files Browse the repository at this point in the history
  • Loading branch information
FullMetalMeowchemist committed Sep 14, 2023
1 parent 2df398d commit 908e71f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
19 changes: 14 additions & 5 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@

# Error and warning messages
SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid."
TEXT_METADATA_LENGTH_MISMATCH_WARNING = (
"The length of the texts and metadatas provided are different. There may be a mismatch "
"between texts and the metadatas length; this may cause undesired results between the joining of "
"embeddings and metadatas."
)


class EmbeddingModel(Enum):
Expand All @@ -44,8 +49,8 @@ def embed(
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts creates embeddings using a model in starpoint. This is a
version of embed_and_join_metadata where joining metadata with the result is
not necessary. The same API is used between the two methods.
version of `embed_and_join_metadata_by_columns` where joining metadata with the result is
not necessary. The same API is used for the two methods.
Args:
texts: List of strings to create embeddings from.
Expand All @@ -60,14 +65,15 @@ def embed(
text_embedding_items = [{"text": text, "metadata": None} for text in texts]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_and_join_metadata(
def embed_and_join_metadata_by_columns(
self,
texts: List[str],
metadatas: List[Dict],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts and creates embeddings using a model in starpoint. Metadata is joined with
the results for ergonomics. Under the hood this is using embed_items.
"""Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items`
instead, as mismatched `texts` and `metadatas` will output undesirable results.
Under the hood this is using `embed_items`.
Args:
texts: List of strings to create embeddings from.
Expand All @@ -82,6 +88,9 @@ def embed_and_join_metadata(
requests.exceptions.SSLError: Failure likely due to network issues.
"""
# TODO: add len + logging here
if len(texts) != len(metadatas):
LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING)

text_embedding_items = [
{
"text": text,
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_embedding_embed(

@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata(
def test_embedding_embed_and_join_metadata_by_columns(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
Expand All @@ -110,7 +110,7 @@ def test_embedding_embed_and_join_metadata(
input_model = embedding.EmbeddingModel.MINILM
expected_item = [{"text": input_text, "metadata": input_metadata}]

actual_json = mock_embedding_client.embed_and_join_metadata(
actual_json = mock_embedding_client.embed_and_join_metadata_by_columns(
[input_text], [input_metadata], input_model
)

Expand Down

0 comments on commit 908e71f

Please sign in to comment.