Skip to content

Commit

Permalink
Fix bugs and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FullMetalMeowchemist committed Sep 16, 2023
1 parent 81a86da commit 4a01bf4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
15 changes: 6 additions & 9 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ def embed_and_join_metadata(
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
if not text_embedding_items:
raise ValueError("text_embedding_items received an empty list.")

texts = list(map(lambda item: item.get(embedding_key), text_embedding_items))
if not texts:
raise ValueError(
"text_embedding_items received an empty list of list of empty items."
)
elif not all(texts):
if not all(texts):
unqualified_indices = list(
more_itertools.locate(texts, lambda x: x is None)
)
Expand All @@ -143,12 +142,10 @@ def embed_and_join_metadata(

# We can also do this operation in the first map that creates texts, but that might make additional operations
# in here a lot more annoying. It's an optimization that shouldn't happen right now.
metadatas = list(
map(lambda item: item.pop(embedding_key), text_embedding_items)
)
list(map(lambda item: item.pop(embedding_key), text_embedding_items))

return self.embed_and_join_metadata_by_columns(
texts=text_embedding_items, metadatas=metadatas, model=model
texts=texts, metadatas=text_embedding_items, model=model
)

def embed_items(
Expand Down
68 changes: 68 additions & 0 deletions python/tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,74 @@ def test_embedding_embed_and_join_metadata_by_columns_mismatch_list(
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "embed text"
input_metadata = {
"metadata1": "metadata1",
"metadata2": "metadata2",
}
embed_key = "text"
input_dict = {embed_key: input_text}
input_dict.update(input_metadata)
test_embedding_items = [input_dict]

expected_item = [{"text": input_text, "metadata": input_metadata}]
input_model = embedding.EmbeddingModel.MINILM

actual_json = mock_embedding_client.embed_and_join_metadata(
test_embedding_items, embed_key, input_model
)

embed_and_join_metadata_by_columns_mock.assert_called_once_with(
texts=[input_text], metadatas=[input_metadata], model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_no_embed_key(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "embed text"
input_metadata = {
"metadata1": "metadata1",
"metadata2": "metadata2",
}
input_dict = {"text": input_text}
input_dict.update(input_metadata)
test_embedding_items = [input_dict]

input_model = embedding.EmbeddingModel.MINILM
embed_key = "no key"

with pytest.raises(ValueError):
actual_json = mock_embedding_client.embed_and_join_metadata(
test_embedding_items, embed_key, input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_no_values(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_model = embedding.EmbeddingModel.MINILM
with pytest.raises(ValueError):
actual_json = mock_embedding_client.embed_and_join_metadata(
[], "text", input_model
)


@patch("starpoint.embedding.requests")
def test_embedding_embed_items(
requests_mock: MagicMock,
Expand Down

0 comments on commit 4a01bf4

Please sign in to comment.