From 0cf0be8906063d09456285be9c9f7ce5789726ae Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Mon, 12 Aug 2024 16:43:44 +0200
Subject: [PATCH] Support pyarrow large_list (#7019)

* Test polars round trip
* Test Features.from_arrow_schema
* Add large attribute to Sequence
* Update get_nested_type to support pa.large_list
* Update generate_from_arrow_type to support pa.LargeListType
* Fix typo
* Rename test
* Add require_polars to test
* Test from_polars large_list
* Update test array_cast with large list
* Support large list in array_cast
* Test cast_array_to_feature for large list
* Support large list in cast_array_to_feature
* Fix support large list in cast_array_to_feature
* Test save_to_disk with a dataset from polars with large_list
* Test Features.reorder_fields_as with large Sequence
* Fix Features.reorder_fields_as by using all Sequence params
* Test save_to/load_from disk round trip with large_list dataset
* Test DatasetInfo.from_dict with large Sequence
* Test Features to/from dict round trip with large Sequence
* Fix features generate_from_dict by using all Sequence params
* Remove debug comments
* Test cast_array_to_feature with struct array
* Fix cast_array_to_feature for struct array
* Test cast_array_to_feature from/to the same Sequence feature dtype
* Fix cast_array_to_feature for the same Sequence feature dtype
* Add more tests for dataset with large Sequence
* Remove Sequence.large
* Remove Sequence.large from tests
* Add LargeList to tests
* Replace tests with Sequence.large with LargeList
* Replace Sequence.large with LargeList in test_dataset_info_from_dict
* Implement LargeList
* Test features to_yaml_list with LargeList
* Support LargeList in Features._to_yaml_list
* Test Features.from_dict with LargeList
* Support LargeList in Features.from_dict
* Test Features from_yaml_list with LargeList
* Support LargeList in Features._from_yaml_list
* Test get_nested_type with scalar/list features
* Support LargeList in get_nested_type
* Test generate_from_arrow_type with primitive/nested data types
* Support LargeList in generate_from_arrow_type
* Remove Sequence of dict from test cast_array_to_feature
* Support LargeList in cast_array_to_feature
* Test Features.encode_example
* Test encode_nested_example with list types
* Support LargeList in encode_nested_example
* Test check_non_null_non_empty_recursive with list types
* Support LargeList in check_non_null_non_empty_recursive
* Test require_decoding with list types
* Support LargeList in require_decoding
* Test decode_nested_example with list types
* Support LargeList in decode_nested_example
* Test generate_from_dict with list types
* Test Features.from_dict with list types
* Test _visit with list types
* Support LargeList in _visit
* Test require_storage_cast with list types
* Support LargeList in require_storage_cast
* Refactor test_require_storage_cast_with_list_types
* Test require_storage_embed with list types
* Support LargeList in require_storage_embed
* Fix test_features_reorder_fields_as
* Test Features.reorder_fields_as with list types
* Test Features.reorder_fields_as with dict within list types
* Support LargeList in Features.reorder_fields_as
* Test Features.flatten with list types
* Test embed_array_storage with list types
* Support LargeList in embed_array_storage
* Delete unused tf_utils.is_numeric_feature
* Add LargeList docstring
* Add LargeList to main classes docs
* Address requested changes

---
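Reviewer note: a minimal usage sketch of the new `LargeList` feature type, adapted from the tests added in this patch (e.g. `test_dataset_from_dict_with_large_list` and the to/from-dict round-trip test); it assumes a `datasets` build that includes this change:

```python
# Usage sketch adapted from the tests added below (assumes a
# datasets build that includes this patch).
import pyarrow as pa

from datasets import Dataset, Features, LargeList, Value

# Declare a column backed by pa.large_list (64-bit offsets) instead of
# the 32-bit-offset pa.list_ produced by Sequence or a plain [] feature.
features = Features({"col_1": LargeList(Value("int64"))})
ds = Dataset.from_dict({"col_1": [[1, 2], [3, 4]]}, features=features)

# The underlying Arrow storage really is a large_list.
assert pa.types.is_large_list(ds.data.schema.field("col_1").type)

# The feature also round-trips through the serialized metadata.
assert Features.from_dict(features.to_dict()) == features
```
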
.../source/package_reference/main_classes.mdx | 6 +- src/datasets/features/__init__.py | 3 +- src/datasets/features/features.py | 170 ++++++---- src/datasets/table.py | 46 ++- src/datasets/utils/tf_utils.py | 18 -- tests/features/test_features.py | 306 +++++++++++++++++- tests/test_arrow_dataset.py | 72 +++++ tests/test_info.py | 13 + tests/test_table.py | 103 +++++- 9 files changed, 643 insertions(+), 94 deletions(-) diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 9c964bc56d5..86257f624b4 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -211,11 +211,13 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable [[autodoc]] datasets.Features -[[autodoc]] datasets.Sequence +[[autodoc]] datasets.Value [[autodoc]] datasets.ClassLabel -[[autodoc]] datasets.Value +[[autodoc]] datasets.LargeList + +[[autodoc]] datasets.Sequence [[autodoc]] datasets.Translation diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index b3c03fbfed7..35ebfb4ac0c 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -6,6 +6,7 @@ "Array5D", "ClassLabel", "Features", + "LargeList", "Sequence", "Value", "Image", @@ -13,6 +14,6 @@ "TranslationVariableLanguages", ] from .audio import Audio -from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, Sequence, Value +from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value from .image import Image from .translation import Translation, TranslationVariableLanguages diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index df8eeffd306..dc7c0f8c850 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1163,6 +1163,24 @@ class Sequence: _type: str = field(default="Sequence", init=False, repr=False) +@dataclass +class LargeList: + """Feature type for large list data composed of child feature data type. + + It is backed by `pyarrow.LargeListType`, which is like `pyarrow.ListType` but with 64-bit rather than 32-bit offsets. + + Args: + dtype: + Child feature data type of each item within the large list. 
+ """ + + dtype: Any + id: Optional[str] = None + # Automatically constructed + pa_type: ClassVar[Any] = None + _type: str = field(default="LargeList", init=False, repr=False) + + FeatureType = Union[ dict, list, @@ -1171,6 +1189,7 @@ class Sequence: ClassLabel, Translation, TranslationVariableLanguages, + LargeList, Sequence, Array2D, Array3D, @@ -1188,12 +1207,14 @@ def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = Non """ if obj is None: return False - elif isinstance(obj, (list, tuple)) and (schema is None or isinstance(schema, (list, tuple, Sequence))): + elif isinstance(obj, (list, tuple)) and (schema is None or isinstance(schema, (list, tuple, LargeList, Sequence))): if len(obj) > 0: if schema is None: pass elif isinstance(schema, (list, tuple)): schema = schema[0] + elif isinstance(schema, LargeList): + schema = schema.dtype else: schema = schema.feature return _check_non_null_non_empty_recursive(obj[0], schema) @@ -1225,12 +1246,17 @@ def get_nested_type(schema: FeatureType) -> pa.DataType: raise ValueError("When defining list feature, you should just provide one example of the inner type") value_type = get_nested_type(schema[0]) return pa.list_(value_type) + elif isinstance(schema, LargeList): + value_type = get_nested_type(schema.dtype) + return pa.large_list(value_type) elif isinstance(schema, Sequence): value_type = get_nested_type(schema.feature) # We allow to reverse list of dict => dict of list for compatibility with tfds if isinstance(schema.feature, dict): - return pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type}) - return pa.list_(value_type, schema.length) + data_type = pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type}) + else: + data_type = pa.list_(value_type, schema.length) + return data_type # Other objects are callable which returns their data type (ClassLabel, Array2D, Translation, Arrow datatype creation methods) return schema() @@ -1267,10 +1293,22 @@ def encode_nested_example(schema, obj, level=0): if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt: return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj] return list(obj) + elif isinstance(schema, LargeList): + if obj is None: + return None + else: + if len(obj) > 0: + sub_schema = schema.dtype + for first_elmt in obj: + if _check_non_null_non_empty_recursive(first_elmt, sub_schema): + break + if encode_nested_example(sub_schema, first_elmt, level=level + 1) != first_elmt: + return [encode_nested_example(sub_schema, o, level=level + 1) for o in obj] + return list(obj) elif isinstance(schema, Sequence): if obj is None: return None - # We allow to reverse list of dict => dict of list for compatiblity with tfds + # We allow to reverse list of dict => dict of list for compatibility with tfds if isinstance(schema.feature, dict): # dict of list to fill list_dict = {} @@ -1337,8 +1375,20 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni if decode_nested_example(sub_schema, first_elmt) != first_elmt: return [decode_nested_example(sub_schema, o) for o in obj] return list(obj) + elif isinstance(schema, LargeList): + if obj is None: + return None + else: + sub_schema = schema.dtype + if len(obj) > 0: + for first_elmt in obj: + if _check_non_null_non_empty_recursive(first_elmt, sub_schema): + break + if decode_nested_example(sub_schema, first_elmt) != first_elmt: + return [decode_nested_example(sub_schema, o) for o in obj] + return list(obj) elif isinstance(schema, 
Sequence): - # We allow to reverse list of dict => dict of list for compatiblity with tfds + # We allow to reverse list of dict => dict of list for compatibility with tfds if isinstance(schema.feature, dict): return {k: decode_nested_example([schema.feature[k]], obj[k]) for k in schema.feature} else: @@ -1356,6 +1406,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni ClassLabel.__name__: ClassLabel, Translation.__name__: Translation, TranslationVariableLanguages.__name__: TranslationVariableLanguages, + LargeList.__name__: LargeList, Sequence.__name__: Sequence, Array2D.__name__: Array2D, Array3D.__name__: Array3D, @@ -1406,8 +1457,12 @@ def generate_from_dict(obj: Any): if class_type is None: raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}") + if class_type == LargeList: + dtype = obj.pop("dtype") + return LargeList(generate_from_dict(dtype), **obj) if class_type == Sequence: - return Sequence(feature=generate_from_dict(obj["feature"]), length=obj.get("length", -1)) + feature = obj.pop("feature") + return Sequence(feature=generate_from_dict(feature), **obj) field_names = {f.name for f in fields(class_type)} return class_type(**{k: v for k, v in obj.items() if k in field_names}) @@ -1432,6 +1487,9 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType: if isinstance(feature, (dict, tuple, list)): return [feature] return Sequence(feature=feature) + elif isinstance(pa_type, pa.LargeListType): + dtype = generate_from_arrow_type(pa_type.value_type) + return LargeList(dtype) elif isinstance(pa_type, _ArrayXDExtensionType): array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims] return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) @@ -1537,6 +1595,8 @@ def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureT out = func({k: _visit(f, func) for k, f in feature.items()}) elif isinstance(feature, (list, tuple)): out = func([_visit(feature[0], func)]) + elif isinstance(feature, LargeList): + out = func(LargeList(_visit(feature.dtype, func))) elif isinstance(feature, Sequence): out = func(Sequence(_visit(feature.feature, func), length=feature.length)) else: @@ -1558,6 +1618,8 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False return any(require_decoding(f) for f in feature.values()) elif isinstance(feature, (list, tuple)): return require_decoding(feature[0]) + elif isinstance(feature, LargeList): + return require_decoding(feature.dtype) elif isinstance(feature, Sequence): return require_decoding(feature.feature) else: @@ -1576,6 +1638,8 @@ def require_storage_cast(feature: FeatureType) -> bool: return any(require_storage_cast(f) for f in feature.values()) elif isinstance(feature, (list, tuple)): return require_storage_cast(feature[0]) + elif isinstance(feature, LargeList): + return require_storage_cast(feature.dtype) elif isinstance(feature, Sequence): return require_storage_cast(feature.feature) else: @@ -1594,6 +1658,8 @@ def require_storage_embed(feature: FeatureType) -> bool: return any(require_storage_cast(f) for f in feature.values()) elif isinstance(feature, (list, tuple)): return require_storage_cast(feature[0]) + elif isinstance(feature, LargeList): + return require_storage_cast(feature.dtype) elif isinstance(feature, Sequence): return require_storage_cast(feature.feature) else: @@ -1641,7 +1707,7 @@ class Features(dict): A [`~datasets.Sequence`] with a internal dictionary feature will be 
automatically converted into a dictionary of - lists. This behavior is implemented to have a compatilbity layer with the TensorFlow Datasets library but may be + lists. This behavior is implemented to have a compatibility layer with the TensorFlow Datasets library but may be un-wanted in some cases. If you don't want this behavior, you can use a python `list` instead of the [`~datasets.Sequence`]. @@ -1771,37 +1837,22 @@ def simplify(feature: dict) -> dict: if not isinstance(feature, dict): raise TypeError(f"Expected a dict but got a {type(feature)}: {feature}") - # - # sequence: -> sequence: int32 - # dtype: int32 -> - # - if isinstance(feature.get("sequence"), dict) and list(feature["sequence"]) == ["dtype"]: - feature["sequence"] = feature["sequence"]["dtype"] - - # - # sequence: -> sequence: - # struct: -> - name: foo - # - name: foo -> dtype: int32 - # dtype: int32 -> - # - if isinstance(feature.get("sequence"), dict) and list(feature["sequence"]) == ["struct"]: - feature["sequence"] = feature["sequence"]["struct"] - - # - # list: -> list: int32 - # dtype: int32 -> - # - if isinstance(feature.get("list"), dict) and list(feature["list"]) == ["dtype"]: - feature["list"] = feature["list"]["dtype"] - - # - # list: -> list: - # struct: -> - name: foo - # - name: foo -> dtype: int32 - # dtype: int32 -> - # - if isinstance(feature.get("list"), dict) and list(feature["list"]) == ["struct"]: - feature["list"] = feature["list"]["struct"] + for list_type in ["large_list", "list", "sequence"]: + # + # list_type: -> list_type: int32 + # dtype: int32 -> + # + if isinstance(feature.get(list_type), dict) and list(feature[list_type]) == ["dtype"]: + feature[list_type] = feature[list_type]["dtype"] + + # + # list_type: -> list_type: + # struct: -> - name: foo + # - name: foo -> dtype: int32 + # dtype: int32 -> + # + if isinstance(feature.get(list_type), dict) and list(feature[list_type]) == ["struct"]: + feature[list_type] = feature[list_type]["struct"] # # class_label: -> class_label: @@ -1819,7 +1870,10 @@ def simplify(feature: dict) -> dict: def to_yaml_inner(obj: Union[dict, list]) -> dict: if isinstance(obj, dict): _type = obj.pop("_type", None) - if _type == "Sequence": + if _type == "LargeList": + value_type = obj.pop("dtype") + return simplify({"large_list": to_yaml_inner(value_type), **obj}) + elif _type == "Sequence": _feature = obj.pop("feature") return simplify({"sequence": to_yaml_inner(_feature), **obj}) elif _type == "Value": @@ -1858,18 +1912,14 @@ def _from_yaml_list(cls, yaml_data: list) -> "Features": def unsimplify(feature: dict) -> dict: if not isinstance(feature, dict): raise TypeError(f"Expected a dict but got a {type(feature)}: {feature}") - # - # sequence: int32 -> sequence: - # -> dtype: int32 - # - if isinstance(feature.get("sequence"), str): - feature["sequence"] = {"dtype": feature["sequence"]} - # - # list: int32 -> list: - # -> dtype: int32 - # - if isinstance(feature.get("list"), str): - feature["list"] = {"dtype": feature["list"]} + + for list_type in ["large_list", "list", "sequence"]: + # + # list_type: int32 -> list_type: + # -> dtype: int32 + # + if isinstance(feature.get(list_type), str): + feature[list_type] = {"dtype": feature[list_type]} # # class_label: -> class_label: @@ -1891,6 +1941,9 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]: if not obj: return {} _type = next(iter(obj)) + if _type == "large_list": + _dtype = unsimplify(obj).pop(_type) + return {"dtype": from_yaml_inner(_dtype), **obj, "_type": "LargeList"} if _type == "sequence": 
_feature = unsimplify(obj).pop(_type) return {"feature": from_yaml_inner(_feature), **obj, "_type": "Sequence"} @@ -2073,11 +2126,11 @@ def reorder_fields_as(self, other: "Features") -> "Features": Example:: >>> from datasets import Features, Sequence, Value - >>> # let's say we have to features with a different order of nested fields (for a and b for example) + >>> # let's say we have two features with a different order of nested fields (for a and b for example) >>> f1 = Features({"root": Sequence({"a": Value("string"), "b": Value("string")})}) >>> f2 = Features({"root": {"b": Sequence(Value("string")), "a": Sequence(Value("string"))}}) >>> assert f1.type != f2.type - >>> # re-ordering keeps the base structure (here Sequence is defined at the root level), but make the fields order match + >>> # re-ordering keeps the base structure (here Sequence is defined at the root level), but makes the fields order match >>> f1.reorder_fields_as(f2) {'root': Sequence(feature={'b': Value(dtype='string', id=None), 'a': Value(dtype='string', id=None)}, length=-1, id=None)} >>> assert f1.reorder_fields_as(f2).type == f2.type @@ -2092,15 +2145,16 @@ def recursive_reorder(source, target, stack=""): else: target = [target] if isinstance(source, Sequence): - source, id_, length = source.feature, source.id, source.length + sequence_kwargs = vars(source).copy() + source = sequence_kwargs.pop("feature") if isinstance(source, dict): source = {k: [v] for k, v in source.items()} reordered = recursive_reorder(source, target, stack) - return Sequence({k: v[0] for k, v in reordered.items()}, id=id_, length=length) + return Sequence({k: v[0] for k, v in reordered.items()}, **sequence_kwargs) else: source = [source] reordered = recursive_reorder(source, target, stack) - return Sequence(reordered[0], id=id_, length=length) + return Sequence(reordered[0], **sequence_kwargs) elif isinstance(source, dict): if not isinstance(target, dict): raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position) @@ -2118,6 +2172,10 @@ def recursive_reorder(source, target, stack=""): if len(source) != len(target): raise ValueError(f"Length mismatch: between {source} and {target}" + stack_position) return [recursive_reorder(source[i], target[i], stack + ".") for i in range(len(target))] + elif isinstance(source, LargeList): + if not isinstance(target, LargeList): + raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position) + return LargeList(recursive_reorder(source.dtype, target.dtype, stack)) else: return source diff --git a/src/datasets/table.py b/src/datasets/table.py index b5604b998e7..e9e27544220 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1884,7 +1884,7 @@ def array_cast( return array arrays = [_c(array.field(field.name), field.type) for field in pa_type] return pa.StructArray.from_arrays(arrays, fields=list(pa_type), mask=array.is_null()) - elif pa.types.is_list(array.type): + elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type): if pa.types.is_fixed_size_list(pa_type): if _are_list_values_of_length(array, pa_type.list_size): if array.null_count > 0: @@ -1911,6 +1911,10 @@ def array_cast( # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError array_offsets = _combine_list_array_offsets_with_mask(array) return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type)) + elif pa.types.is_large_list(pa_type): + # Merge offsets with the null bitmap to avoid the 
"Null bitmap with offsets slice not supported" ArrowNotImplementedError + array_offsets = _combine_list_array_offsets_with_mask(array) + return pa.LargeListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type)) elif pa.types.is_fixed_size_list(array.type): if pa.types.is_fixed_size_list(pa_type): if pa_type.list_size == array.type.list_size: @@ -1923,6 +1927,11 @@ def array_cast( elif pa.types.is_list(pa_type): array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null()) + elif pa.types.is_large_list(pa_type): + array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size + return pa.LargeListArray.from_arrays( + array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null() + ) else: if pa.types.is_string(pa_type): if not allow_primitive_to_str and pa.types.is_primitive(array.type): @@ -1972,7 +1981,7 @@ def cast_array_to_feature( Returns: array (`pyarrow.Array`): the casted array """ - from .features.features import Sequence, get_nested_type + from .features.features import LargeList, Sequence, get_nested_type _c = partial( cast_array_to_feature, @@ -1988,24 +1997,34 @@ def cast_array_to_feature( elif pa.types.is_struct(array.type): # feature must be a dict or Sequence(subfeatures_dict) if isinstance(feature, Sequence) and isinstance(feature.feature, dict): - feature = { - name: Sequence(subfeature, length=feature.length) for name, subfeature in feature.feature.items() - } + sequence_kwargs = vars(feature).copy() + feature = sequence_kwargs.pop("feature") + feature = {name: Sequence(subfeature, **sequence_kwargs) for name, subfeature in feature.items()} if isinstance(feature, dict) and {field.name for field in array.type} == set(feature): if array.type.num_fields == 0: return array arrays = [_c(array.field(name), subfeature) for name, subfeature in feature.items()] return pa.StructArray.from_arrays(arrays, names=list(feature), mask=array.is_null()) - elif pa.types.is_list(array.type): - # feature must be either [subfeature] or Sequence(subfeature) + elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type): + # feature must be either [subfeature] or LargeList(subfeature) or Sequence(subfeature) if isinstance(feature, list): casted_array_values = _c(array.values, feature[0]) - if casted_array_values.type == array.values.type: + if pa.types.is_list(array.type) and casted_array_values.type == array.values.type: + # Both array and feature have equal list type and values (within the list) type return array else: # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError array_offsets = _combine_list_array_offsets_with_mask(array) return pa.ListArray.from_arrays(array_offsets, casted_array_values) + elif isinstance(feature, LargeList): + casted_array_values = _c(array.values, feature.dtype) + if pa.types.is_large_list(array.type) and casted_array_values.type == array.values.type: + # Both array and feature have equal large_list type and values (within the list) type + return array + else: + # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError + array_offsets = _combine_list_array_offsets_with_mask(array) + return pa.LargeListArray.from_arrays(array_offsets, casted_array_values) elif isinstance(feature, Sequence): if feature.length > -1: if _are_list_values_of_length(array, 
feature.length): @@ -2042,7 +2061,8 @@ def cast_array_to_feature( return pa.FixedSizeListArray.from_arrays(_c(array_values, feature.feature), feature.length) else: casted_array_values = _c(array.values, feature.feature) - if casted_array_values.type == array.values.type: + if pa.types.is_list(array.type) and casted_array_values.type == array.values.type: + # Both array and feature have equal list type and values (within the list) type return array else: # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError @@ -2053,6 +2073,9 @@ def cast_array_to_feature( if isinstance(feature, list): array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature[0]), mask=array.is_null()) + elif isinstance(feature, LargeList): + array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size + return pa.LargeListArray.from_arrays(array_offsets, _c(array.values, feature.dtype), mask=array.is_null()) elif isinstance(feature, Sequence): if feature.length > -1: if feature.length == array.type.list_size: @@ -2128,6 +2151,11 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"): return pa.ListArray.from_arrays(array_offsets, _e(array.values, feature[0])) if isinstance(feature, Sequence) and feature.length == -1: return pa.ListArray.from_arrays(array_offsets, _e(array.values, feature.feature)) + elif pa.types.is_large_list(array.type): + # feature must be LargeList(subfeature) + # Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError + array_offsets = _combine_list_array_offsets_with_mask(array) + return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.dtype)) elif pa.types.is_fixed_size_list(array.type): # feature must be Sequence(subfeature) if isinstance(feature, Sequence) and feature.length > -1: diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py index b69f5c85b2c..2de35a943e7 100644 --- a/src/datasets/utils/tf_utils.py +++ b/src/datasets/utils/tf_utils.py @@ -67,24 +67,6 @@ def is_numeric_pa_type(pa_type): return pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) or pa.types.is_decimal(pa_type) -def is_numeric_feature(feature): - from .. 
import ClassLabel, Sequence, Value - from ..features.features import _ArrayXD - - if isinstance(feature, Sequence): - return is_numeric_feature(feature.feature) - elif isinstance(feature, list): - return is_numeric_feature(feature[0]) - elif isinstance(feature, _ArrayXD): - return is_numeric_pa_type(feature().storage_dtype) - elif isinstance(feature, Value): - return is_numeric_pa_type(feature()) - elif isinstance(feature, ClassLabel): - return True - else: - return False - - def np_get_batch( indices, dataset, cols_to_retain, collate_fn, collate_fn_args, columns_to_np_types, return_dict=False ): diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 913d8edeebe..8ab4baced7a 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -1,7 +1,7 @@ import datetime from typing import List, Tuple from unittest import TestCase -from unittest.mock import patch +from unittest.mock import MagicMock, patch import numpy as np import pandas as pd @@ -10,15 +10,23 @@ from datasets import Array2D from datasets.arrow_dataset import Dataset -from datasets.features import Audio, ClassLabel, Features, Image, Sequence, Value +from datasets.features import Audio, ClassLabel, Features, Image, LargeList, Sequence, Value from datasets.features.features import ( _align_features, _arrow_to_datasets_dtype, _cast_to_python_objects, _check_if_features_can_be_aligned, + _check_non_null_non_empty_recursive, + _visit, cast_to_python_objects, + decode_nested_example, encode_nested_example, + generate_from_arrow_type, generate_from_dict, + get_nested_type, + require_decoding, + require_storage_cast, + require_storage_embed, string_to_arrow, ) from datasets.features.translation import Translation, TranslationVariableLanguages @@ -28,6 +36,10 @@ from ..utils import require_jax, require_numpy1_on_windows, require_tf, require_torch +def list_with(item): + return [item] + + class FeaturesTest(TestCase): def test_from_arrow_schema_simple(self): data = {"a": [{"b": {"c": "text"}}] * 10, "foo": [1] * 10} @@ -386,6 +398,28 @@ def test_class_label_to_and_from_dict(class_label_arg, tmp_path_factory): assert generated_class_label == class_label +@pytest.mark.parametrize( + "schema", + [[Audio()], LargeList(Audio()), Sequence(Audio())], +) +def test_decode_nested_example_with_list_types(schema, monkeypatch): + mock_decode_example = MagicMock() + monkeypatch.setattr(Audio, "decode_example", mock_decode_example) + audio_example = {"path": "dummy_audio_path"} + _ = decode_nested_example(schema, [audio_example]) + assert mock_decode_example.called + assert mock_decode_example.call_args.args[0] == audio_example + + +@pytest.mark.parametrize( + "schema", + [[ClassLabel(names=["a", "b"])], LargeList(ClassLabel(names=["a", "b"])), Sequence(ClassLabel(names=["a", "b"]))], +) +def test_encode_nested_example_with_list_types(schema): + result = encode_nested_example(schema, ["b"]) + assert result == [1] + + @pytest.mark.parametrize("inner_type", [Value("int32"), {"subcolumn": Value("int32")}]) def test_encode_nested_example_sequence_with_none(inner_type): schema = Sequence(inner_type) @@ -394,6 +428,21 @@ def test_encode_nested_example_sequence_with_none(inner_type): assert result is None +@pytest.mark.parametrize( + "features_dict, example, expected_encoded_example", + [ + ({"col_1": ClassLabel(names=["a", "b"])}, {"col_1": "b"}, {"col_1": 1}), + ({"col_1": [ClassLabel(names=["a", "b"])]}, {"col_1": ["b"]}, {"col_1": [1]}), + ({"col_1": LargeList(ClassLabel(names=["a", "b"]))}, {"col_1": 
["b"]}, {"col_1": [1]}), + ({"col_1": Sequence(ClassLabel(names=["a", "b"]))}, {"col_1": ["b"]}, {"col_1": [1]}), + ], +) +def test_encode_example(features_dict, example, expected_encoded_example): + features = Features(features_dict) + encoded_example = features.encode_example(example) + assert encoded_example == expected_encoded_example + + def test_encode_batch_with_example_with_empty_first_elem(): features = Features( { @@ -624,6 +673,8 @@ def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast): Features({"foo": Sequence({"bar": Value("int32")})}), Features({"foo": [Value("int32")]}), Features({"foo": [{"bar": Value("int32")}]}), + Features({"foo": LargeList(Value("int32"))}), + Features({"foo": LargeList({"bar": Value("int32")})}), ] NESTED_CUSTOM_FEATURES = [ @@ -632,11 +683,13 @@ def test_dont_iterate_over_each_element_in_a_list(self, mocked_cast): Features({"foo": Sequence({"bar": ClassLabel(names=["negative", "positive"])})}), Features({"foo": [ClassLabel(names=["negative", "positive"])]}), Features({"foo": [{"bar": ClassLabel(names=["negative", "positive"])}]}), + Features({"foo": LargeList(ClassLabel(names=["negative", "positive"]))}), + Features({"foo": LargeList({"bar": ClassLabel(names=["negative", "positive"])})}), ] @pytest.mark.parametrize("features", SIMPLE_FEATURES + CUSTOM_FEATURES + NESTED_FEATURES + NESTED_CUSTOM_FEATURES) -def test_features_to_dict(features: Features): +def test_features_to_dict_and_from_dict_round_trip(features: Features): features_dict = features.to_dict() assert isinstance(features_dict, dict) reloaded = Features.from_dict(features_dict) @@ -651,6 +704,119 @@ def test_features_to_yaml_list(features: Features): assert features == reloaded +@pytest.mark.parametrize( + "features_dict, expected_features_dict", + [ + ({"col": [{"sub_col": Value("int32")}]}, {"col": [{"sub_col": Value("int32")}]}), + ({"col": LargeList({"sub_col": Value("int32")})}, {"col": LargeList({"sub_col": Value("int32")})}), + ({"col": Sequence({"sub_col": Value("int32")})}, {"col.sub_col": Sequence(Value("int32"))}), + ], +) +def test_features_flatten_with_list_types(features_dict, expected_features_dict): + features = Features(features_dict) + flattened_features = features.flatten() + assert flattened_features == Features(expected_features_dict) + + +@pytest.mark.parametrize( + "deserialized_features_dict, expected_features_dict", + [ + ( + {"col": [{"dtype": "int32", "_type": "Value"}]}, + {"col": [Value("int32")]}, + ), + ( + {"col": {"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}}, + {"col": LargeList(Value("int32"))}, + ), + ( + {"col": {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "Sequence"}}, + {"col": Sequence(Value("int32"))}, + ), + ( + {"col": [{"sub_col": {"dtype": "int32", "_type": "Value"}}]}, + {"col": [{"sub_col": Value("int32")}]}, + ), + ( + {"col": {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}}, + {"col": LargeList({"sub_col": Value("int32")})}, + ), + ( + {"col": {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "Sequence"}}, + {"col": Sequence({"sub_col": Value("int32")})}, + ), + ], +) +def test_features_from_dict_with_list_types(deserialized_features_dict, expected_features_dict): + features = Features.from_dict(deserialized_features_dict) + assert features == Features(expected_features_dict) + + +@pytest.mark.parametrize( + "deserialized_feature_dict, expected_feature", + [ + ( + [{"dtype": "int32", "_type": "Value"}], + [Value("int32")], + ), + ( + 
{"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}, + LargeList(Value("int32")), + ), + ( + {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "Sequence"}, + Sequence(Value("int32")), + ), + ( + [{"sub_col": {"dtype": "int32", "_type": "Value"}}], + [{"sub_col": Value("int32")}], + ), + ( + {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}, + LargeList({"sub_col": Value("int32")}), + ), + ( + {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "Sequence"}, + Sequence({"sub_col": Value("int32")}), + ), + ], +) +def test_generate_from_dict_with_list_types(deserialized_feature_dict, expected_feature): + feature = generate_from_dict(deserialized_feature_dict) + assert feature == expected_feature + + +@pytest.mark.parametrize( + "features_dict, expected_features_yaml_list", + [ + ({"col": LargeList(Value("int32"))}, [{"name": "col", "large_list": "int32"}]), + ( + {"col": LargeList({"sub_col": Value("int32")})}, + [{"name": "col", "large_list": [{"dtype": "int32", "name": "sub_col"}]}], + ), + ], +) +def test_features_to_yaml_list_with_large_list(features_dict, expected_features_yaml_list): + features = Features(features_dict) + features_yaml_list = features._to_yaml_list() + assert features_yaml_list == expected_features_yaml_list + + +@pytest.mark.parametrize( + "features_yaml_list, expected_features_dict", + [ + ([{"name": "col", "large_list": "int32"}], {"col": LargeList(Value("int32"))}), + ( + [{"name": "col", "large_list": [{"dtype": "int32", "name": "sub_col"}]}], + {"col": LargeList({"sub_col": Value("int32")})}, + ), + ], +) +def test_features_from_yaml_list_with_large_list(features_yaml_list, expected_features_dict): + features = Features._from_yaml_list(features_yaml_list) + assert features == Features(expected_features_dict) + + @pytest.mark.parametrize("features", SIMPLE_FEATURES + CUSTOM_FEATURES + NESTED_FEATURES + NESTED_CUSTOM_FEATURES) def test_features_to_arrow_schema(features: Features): arrow_schema = features.arrow_schema @@ -696,3 +862,137 @@ def test_features_alignment(features: Tuple[List[Features], Features]): inputs, expected = features _check_if_features_can_be_aligned(inputs) # Check that we can align, will raise otherwise. 
assert _align_features(inputs) == expected + + +@pytest.mark.parametrize("dtype", [pa.int32, pa.string]) +def test_features_from_arrow_schema_primitive_data_type(dtype): + schema = pa.schema([("column_name", dtype())]) + assert schema == Features.from_arrow_schema(schema).arrow_schema + + +@pytest.mark.parametrize("scalar_dtype", [pa.int32, pa.string]) +@pytest.mark.parametrize("list_dtype", [pa.list_, pa.large_list]) +def test_features_from_arrow_schema_list_data_type(list_dtype, scalar_dtype): + schema = pa.schema([("column_name", list_dtype(scalar_dtype()))]) + assert schema == Features.from_arrow_schema(schema).arrow_schema + + +@pytest.mark.parametrize( + "feature, other_feature", + [ + ([Value("int64")], [Value("int64")]), + (LargeList(Value("int64")), LargeList(Value("int64"))), + (Sequence(Value("int64")), Sequence(Value("int64"))), + ( + [{"sub_col_1": Value("int64"), "sub_col_2": Value("int64")}], + [{"sub_col_2": Value("int64"), "sub_col_1": Value("int64")}], + ), + ( + LargeList({"sub_col_1": Value("int64"), "sub_col_2": Value("int64")}), + LargeList({"sub_col_2": Value("int64"), "sub_col_1": Value("int64")}), + ), + ( + Sequence({"sub_col_1": Value("int64"), "sub_col_2": Value("int64")}), + Sequence({"sub_col_2": Value("int64"), "sub_col_1": Value("int64")}), + ), + ], +) +def test_features_reorder_fields_as_with_list_types(feature, other_feature): + features = Features({"col": feature}) + other_features = Features({"col": other_feature}) + new_features = features.reorder_fields_as(other_features) + assert new_features.type == other_features.type + + +@pytest.mark.parametrize( + "feature, expected_arrow_data_type", [(Value("int64"), pa.int64), (Value("string"), pa.string)] +) +def test_get_nested_type_with_scalar_feature(feature, expected_arrow_data_type): + arrow_data_type = get_nested_type(feature) + assert arrow_data_type == expected_arrow_data_type() + + +@pytest.mark.parametrize( + "scalar_feature, expected_arrow_primitive_data_type", [(Value("int64"), pa.int64), (Value("string"), pa.string)] +) +@pytest.mark.parametrize( + "list_feature, expected_arrow_nested_data_type", + [(list_with, pa.list_), (LargeList, pa.large_list), (Sequence, pa.list_)], +) +def test_get_nested_type_with_list_feature( + list_feature, expected_arrow_nested_data_type, scalar_feature, expected_arrow_primitive_data_type +): + feature = list_feature(scalar_feature) + arrow_data_type = get_nested_type(feature) + assert arrow_data_type == expected_arrow_nested_data_type(expected_arrow_primitive_data_type()) + + +@pytest.mark.parametrize( + "arrow_primitive_data_type, expected_feature", [(pa.int32, Value("int32")), (pa.string, Value("string"))] +) +def test_generate_from_arrow_type_with_arrow_primitive_data_type(arrow_primitive_data_type, expected_feature): + arrow_data_type = arrow_primitive_data_type() + feature = generate_from_arrow_type(arrow_data_type) + assert feature == expected_feature + + +@pytest.mark.parametrize( + "arrow_primitive_data_type, expected_scalar_feature", [(pa.int32, Value("int32")), (pa.string, Value("string"))] +) +@pytest.mark.parametrize( + "arrow_nested_data_type, expected_list_feature", [(pa.list_, Sequence), (pa.large_list, LargeList)] +) +def test_generate_from_arrow_type_with_arrow_nested_data_type( + arrow_nested_data_type, expected_list_feature, arrow_primitive_data_type, expected_scalar_feature +): + arrow_data_type = arrow_nested_data_type(arrow_primitive_data_type()) + feature = generate_from_arrow_type(arrow_data_type) + expected_feature = 
expected_list_feature(expected_scalar_feature) + assert feature == expected_feature + + +@pytest.mark.parametrize( + "schema", + [[ClassLabel(names=["a", "b"])], LargeList(ClassLabel(names=["a", "b"])), Sequence(ClassLabel(names=["a", "b"]))], +) +def test_check_non_null_non_empty_recursive_with_list_types(schema): + assert _check_non_null_non_empty_recursive([], schema) is False + + +@pytest.mark.parametrize( + "schema", + [ + [[ClassLabel(names=["a", "b"])]], + LargeList(LargeList(ClassLabel(names=["a", "b"]))), + Sequence(Sequence(ClassLabel(names=["a", "b"]))), + ], +) +def test_check_non_null_non_empty_recursive_with_nested_list_types(schema): + assert _check_non_null_non_empty_recursive([[]], schema) is False + + +@pytest.mark.parametrize("feature", [[Audio()], LargeList(Audio()), Sequence(Audio())]) +def test_require_decoding_with_list_types(feature): + assert require_decoding(feature) + + +@pytest.mark.parametrize("feature", [[Audio()], LargeList(Audio()), Sequence(Audio())]) +def test_require_storage_cast_with_list_types(feature): + assert require_storage_cast(feature) + + +@pytest.mark.parametrize("feature", [[Audio()], LargeList(Audio()), Sequence(Audio())]) +def test_require_storage_embed_with_list_types(feature): + assert require_storage_embed(feature) + + +@pytest.mark.parametrize( + "feature, expected", + [([Value("int32")], [1]), (LargeList(Value("int32")), LargeList(1)), (Sequence(Value("int32")), Sequence(1))], +) +def test_visit_with_list_types(feature, expected): + def func(x): + return 1 if isinstance(x, Value) else x + + result = _visit(feature, func) + assert result == expected diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index f747ef7a980..95fec403a64 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -32,6 +32,7 @@ ClassLabel, Features, Image, + LargeList, Sequence, Translation, TranslationVariableLanguages, @@ -4916,3 +4917,74 @@ def test_dataset_batch(): assert len(batches[2]["text"]) == 2 assert batches[2]["id"] == [8, 9] assert batches[2]["text"] == ["Text 8", "Text 9"] + + +def test_dataset_from_dict_with_large_list(): + data = {"col_1": [[1, 2], [3, 4]]} + features = Features({"col_1": LargeList(Value("int64"))}) + ds = Dataset.from_dict(data, features=features) + assert isinstance(ds, Dataset) + assert pa.types.is_large_list(ds.data.schema.field("col_1").type) + + +def test_dataset_save_to_disk_with_large_list(tmp_path): + data = {"col_1": [[1, 2], [3, 4]]} + features = Features({"col_1": LargeList(Value("int64"))}) + ds = Dataset.from_dict(data, features=features) + dataset_path = tmp_path / "dataset_dir" + ds.save_to_disk(dataset_path) + assert (dataset_path / "data-00000-of-00001.arrow").exists() + + +def test_dataset_save_to_disk_and_load_from_disk_round_trip_with_large_list(tmp_path): + data = {"col_1": [[1, 2], [3, 4]]} + features = Features({"col_1": LargeList(Value("int64"))}) + ds = Dataset.from_dict(data, features=features) + dataset_path = tmp_path / "dataset_dir" + ds.save_to_disk(dataset_path) + assert (dataset_path / "data-00000-of-00001.arrow").exists() + loaded_ds = load_from_disk(dataset_path) + assert len(loaded_ds) == len(ds) + assert loaded_ds.features == ds.features + assert loaded_ds.to_dict() == ds.to_dict() + + +@require_polars +def test_from_polars_with_large_list(): + import polars as pl + + df = pl.from_dict({"col_1": [[1, 2], [3, 4]]}) + ds = Dataset.from_polars(df) + assert isinstance(ds, Dataset) + + +@require_polars +def 
test_from_polars_save_to_disk_with_large_list(tmp_path): + import polars as pl + + df = pl.from_dict({"col_1": [[1, 2], [3, 4]]}) + ds = Dataset.from_polars(df) + dataset_path = tmp_path / "dataset_dir" + ds.save_to_disk(dataset_path) + assert (dataset_path / "data-00000-of-00001.arrow").exists() + + +@require_polars +def test_from_polars_save_to_disk_and_load_from_disk_round_trip_with_large_list(tmp_path): + import polars as pl + + df = pl.from_dict({"col_1": [[1, 2], [3, 4]]}) + ds = Dataset.from_polars(df) + dataset_path = tmp_path / "dataset_dir" + ds.save_to_disk(dataset_path) + assert (dataset_path / "data-00000-of-00001.arrow").exists() + loaded_ds = load_from_disk(dataset_path) + assert len(loaded_ds) == len(ds) + assert loaded_ds.features == ds.features + assert loaded_ds.to_dict() == ds.to_dict() + + +@require_polars +def test_polars_round_trip(): + ds = Dataset.from_dict({"x": [[1, 2], [3, 4, 5]], "y": ["a", "b"]}) + assert isinstance(Dataset.from_polars(ds.to_polars()), Dataset) diff --git a/tests/test_info.py b/tests/test_info.py index e128011c136..1439a7a3b47 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -5,6 +5,7 @@ from datasets.features.features import Features, Value from datasets.info import DatasetInfo, DatasetInfosDict +from datasets.utils.py_utils import asdict @pytest.mark.parametrize( @@ -164,3 +165,15 @@ def test_from_merge_same_dataset_infos(dataset_info): assert dataset_info == dataset_info_merged else: assert DatasetInfo() == dataset_info_merged + + +def test_dataset_info_from_dict_with_large_list(): + dataset_info_dict = { + "citation": "", + "description": "", + "features": {"col_1": {"dtype": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}}, + "homepage": "", + "license": "", + } + dataset_info = DatasetInfo.from_dict(dataset_info_dict) + assert asdict(dataset_info) == dataset_info_dict diff --git a/tests/test_table.py b/tests/test_table.py index 71e2e45af20..c624f98d3f0 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1,14 +1,16 @@ import copy import pickle from decimal import Decimal +from functools import partial from typing import List, Union +from unittest.mock import MagicMock import numpy as np import pyarrow as pa import pytest -from datasets import Sequence, Value -from datasets.features.features import Array2D, Array2DExtensionType, ClassLabel, Features, Image, get_nested_type +from datasets.features import Array2D, ClassLabel, Features, Image, LargeList, Sequence, Value +from datasets.features.features import Array2DExtensionType, get_nested_type from datasets.table import ( ConcatenationTable, InMemoryTable, @@ -1260,6 +1262,60 @@ def test_cast_list_array_to_features_sequence(arr, slice, target_value_feature): assert casted_array.to_pylist() == arr.to_pylist() +@pytest.mark.parametrize("sequence_feature_dtype", ["string", "int64"]) +@pytest.mark.parametrize("from_list_type", ["list", "fixed_size_list", "large_list"]) +@pytest.mark.parametrize("list_within_struct", [False, True]) +def test_cast_array_to_feature_with_list_array_and_sequence_feature( + list_within_struct, from_list_type, sequence_feature_dtype +): + list_type = { + "list": pa.list_, + "fixed_size_list": partial(pa.list_, list_size=2), + "large_list": pa.large_list, + } + primitive_type = { + "string": pa.string(), + "int64": pa.int64(), + } + to_type = "list" + array_data = [0, 1] + array_type = list_type[from_list_type](pa.int64()) + sequence_feature = Value(sequence_feature_dtype) + expected_array_type = 
list_type[to_type](primitive_type[sequence_feature_dtype]) + if list_within_struct: + array_data = {"col_1": array_data} + array_type = pa.struct({"col_1": array_type}) + sequence_feature = {"col_1": sequence_feature} + expected_array_type = pa.struct({"col_1": expected_array_type}) + feature = Sequence(sequence_feature) + array = pa.array([array_data], type=array_type) + cast_array = cast_array_to_feature(array, feature) + assert cast_array.type == expected_array_type + + +@pytest.mark.parametrize("large_list_feature_value_type", ["string", "int64"]) +@pytest.mark.parametrize("from_list_type", ["list", "fixed_size_list", "large_list"]) +def test_cast_array_to_feature_with_list_array_and_large_list_feature(from_list_type, large_list_feature_value_type): + list_type = { + "list": pa.list_, + "fixed_size_list": partial(pa.list_, list_size=2), + "large_list": pa.large_list, + } + primitive_type = { + "string": pa.string(), + "int64": pa.int64(), + } + to_type = "large_list" + array_data = [0, 1] + array_type = list_type[from_list_type](pa.int64()) + large_list_feature_value = Value(large_list_feature_value_type) + expected_array_type = list_type[to_type](primitive_type[large_list_feature_value_type]) + feature = LargeList(large_list_feature_value) + array = pa.array([array_data], type=array_type) + cast_array = cast_array_to_feature(array, feature) + assert cast_array.type == expected_array_type + + def test_cast_array_xd_to_features_sequence(): arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist() arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64())))) @@ -1293,6 +1349,39 @@ def test_embed_array_storage_nested(image_file): assert isinstance(embedded_images_array.to_pylist()[0]["foo"]["bytes"], bytes) +@pytest.mark.parametrize( + "array, feature, expected_embedded_array_type", + [ + ( + pa.array([[{"path": "image_path"}]], type=pa.list_(Image.pa_type)), + [Image()], + pa.types.is_list, + ), + ( + pa.array([[{"path": "image_path"}]], type=pa.large_list(Image.pa_type)), + LargeList(Image()), + pa.types.is_large_list, + ), + ( + pa.array([[{"path": "image_path"}]], type=pa.list_(Image.pa_type)), + Sequence(Image()), + pa.types.is_list, + ), + ], +) +def test_embed_array_storage_with_list_types(array, feature, expected_embedded_array_type, monkeypatch): + mock_embed_storage = MagicMock( + return_value=pa.StructArray.from_arrays( + [pa.array([b"image_bytes"], type=pa.binary()), pa.array(["image_path"], type=pa.string())], + ["bytes", "path"], + ) + ) + monkeypatch.setattr(Image, "embed_storage", mock_embed_storage) + embedded_images_array = embed_array_storage(array, feature) + assert expected_embedded_array_type(embedded_images_array.type) + assert embedded_images_array.to_pylist() == [[{"bytes": b"image_bytes", "path": "image_path"}]] + + def test_embed_table_storage(image_file): features = Features({"image": Image()}) table = table_cast(pa.table({"image": [image_file]}), features.arrow_schema) @@ -1326,10 +1415,14 @@ def test_table_iter(table, batch_size, drop_last_batch): assert table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() -@pytest.mark.parametrize("to_type", ["list", "fixed_size_list"]) -@pytest.mark.parametrize("from_type", ["list", "fixed_size_list"]) +@pytest.mark.parametrize("to_type", ["list", "fixed_size_list", "large_list"]) +@pytest.mark.parametrize("from_type", ["list", "fixed_size_list", "large_list"]) def test_array_cast(from_type, to_type): - array_type = {"list": pa.list_(pa.int64()), "fixed_size_list": 
pa.list_(pa.int64(), 2)}
+    array_type = {
+        "list": pa.list_(pa.int64()),
+        "fixed_size_list": pa.list_(pa.int64(), 2),
+        "large_list": pa.large_list(pa.int64()),
+    }
     arr = pa.array([[0, 1]], type=array_type[from_type])
     cast_arr = array_cast(arr, array_type[to_type])
     assert cast_arr.type == array_type[to_type]
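
Reviewer note: taken together, the `table.py` changes make every pairing of `list`, `fixed_size_list`, and `large_list` castable in both directions, and teach `cast_array_to_feature` to target `pa.LargeListArray` for a `LargeList` feature. A short sketch of what the new tests above exercise, again assuming a build with this patch:

```python
# Sketch of the new cast paths, adapted from the tests above
# (assumes a datasets build that includes this patch).
import pyarrow as pa

from datasets.features import LargeList, Value
from datasets.table import array_cast, cast_array_to_feature

arr = pa.array([[0, 1]], type=pa.list_(pa.int64()))

# array_cast now converts between 32-bit-offset lists and
# 64-bit-offset large lists in either direction.
large = array_cast(arr, pa.large_list(pa.int64()))
assert pa.types.is_large_list(large.type)

# cast_array_to_feature maps a LargeList feature to a
# pa.LargeListArray, casting the child values as well.
casted = cast_array_to_feature(large, LargeList(Value("string")))
assert casted.type == pa.large_list(pa.string())
```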