Skip to content

Commit

Permalink
Test embed_array_storage with list types
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova committed Aug 6, 2024
1 parent 5c8646b commit f11c56d
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from decimal import Decimal
from functools import partial
from typing import List, Union
from unittest.mock import MagicMock

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -1348,6 +1349,39 @@ def test_embed_array_storage_nested(image_file):
assert isinstance(embedded_images_array.to_pylist()[0]["foo"]["bytes"], bytes)


@pytest.mark.parametrize(
    "array, feature, expected_embedded_array_type",
    [
        # Feature given as a plain Python list, backed by an Arrow list array.
        (
            pa.array([[{"path": "image_path"}]], type=pa.list_(Image.pa_type)),
            [Image()],
            pa.types.is_list,
        ),
        # LargeList feature, backed by an Arrow large_list array.
        (
            pa.array([[{"path": "image_path"}]], type=pa.large_list(Image.pa_type)),
            LargeList(Image()),
            pa.types.is_large_list,
        ),
        # Sequence feature, backed by an Arrow list array.
        (
            pa.array([[{"path": "image_path"}]], type=pa.list_(Image.pa_type)),
            Sequence(Image()),
            pa.types.is_list,
        ),
    ],
)
def test_embed_array_storage_with_list_types(array, feature, expected_embedded_array_type, monkeypatch):
    """Check that embed_array_storage keeps the list flavour of its input.

    ``Image.embed_storage`` is replaced by a mock that returns a fixed
    one-row struct array with ``bytes`` and ``path`` fields. The test then
    asserts that the embedded array's Arrow type satisfies the expected
    predicate (``is_list`` / ``is_large_list``) and that its content is the
    mocked struct wrapped in the original list nesting.
    """
    # Canned storage handed back by the patched Image.embed_storage.
    bytes_column = pa.array([b"image_bytes"], type=pa.binary())
    path_column = pa.array(["image_path"], type=pa.string())
    fake_storage = pa.StructArray.from_arrays([bytes_column, path_column], ["bytes", "path"])
    monkeypatch.setattr(Image, "embed_storage", MagicMock(return_value=fake_storage))

    result = embed_array_storage(array, feature)

    assert expected_embedded_array_type(result.type)
    assert result.to_pylist() == [[{"bytes": b"image_bytes", "path": "image_path"}]]


def test_embed_table_storage(image_file):
features = Features({"image": Image()})
table = table_cast(pa.table({"image": [image_file]}), features.arrow_schema)
Expand Down

0 comments on commit f11c56d

Please sign in to comment.