diff --git a/src/datasets/table.py b/src/datasets/table.py index bfa63173ef9..f18e402786e 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -2000,10 +2000,14 @@ def cast_array_to_feature( sequence_kwargs = vars(feature).copy() feature = sequence_kwargs.pop("feature") feature = {name: Sequence(subfeature, **sequence_kwargs) for name, subfeature in feature.items()} - if isinstance(feature, dict) and {field.name for field in array.type} == set(feature): + if isinstance(feature, dict) and (array_fields := {field.name for field in array.type}) <= set(feature): if array.type.num_fields == 0: return array - arrays = [_c(array.field(name), subfeature) for name, subfeature in feature.items()] + null_array = pa.array([None] * len(array)) + arrays = [ + _c(array.field(name) if name in array_fields else null_array, subfeature) + for name, subfeature in feature.items() + ] return pa.StructArray.from_arrays(arrays, names=list(feature), mask=array.is_null()) elif pa.types.is_list(array.type) or pa.types.is_large_list(array.type): # feature must be either [subfeature] or LargeList(subfeature) or Sequence(subfeature) diff --git a/tests/test_table.py b/tests/test_table.py index c624f98d3f0..2c0cdd1a9b6 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1142,6 +1142,14 @@ def test_cast_decimal_array_to_features(): cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False) +def test_cast_array_to_features_with_struct_with_missing_fields(): + arr = pa.array([{"age": 25}, {"age": 63}]) + feature = {"age": Value("int32"), "name": Value("string")} + cast_array = cast_array_to_feature(arr, feature) + assert cast_array.type == pa.struct({"age": pa.int32(), "name": pa.string()}) + assert cast_array.to_pylist() == [{"age": 25, "name": None}, {"age": 63, "name": None}] + + def test_cast_array_to_features_nested(): arr = pa.array([[{"foo": [0]}]]) assert cast_array_to_feature(arr, [{"foo": Sequence(Value("string"))}]).type == pa.list_(