Skip to content

Commit

Permalink
Fix cast from fixed size list to variable size list (#6243)
Browse files Browse the repository at this point in the history
* Fixed cast from fixed size list to variable size list

* Style

* Style again
  • Loading branch information
mariosasko authored Sep 19, 2023
1 parent 9b21e18 commit 05fe5c0
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _convert_to_arrow(
iterator = iter(iterable)
for key, example in iterator:
iterator_batch = islice(iterator, batch_size - 1)
key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch]
key_examples_list = [(key, example)] + list(iterator_batch)
if len(key_examples_list) < batch_size and drop_last_batch:
return
keys, examples = zip(*key_examples_list)
Expand Down Expand Up @@ -697,7 +697,7 @@ def _iter(self):
if self.batch_size is None or self.batch_size <= 0
else islice(iterator, self.batch_size - 1)
)
key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch]
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
if (
self.drop_last_batch
Expand Down Expand Up @@ -880,7 +880,7 @@ def _iter(self):
if self.batch_size is None or self.batch_size <= 0
else islice(iterator, self.batch_size - 1)
)
key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch]
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
batch = _examples_to_batch(examples)
batch = format_dict(batch) if format_dict else batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
datasets.SplitGenerator(
name=split_name,
gen_kwargs={
"files": [(file, downloaded_file) for file, downloaded_file in zip(files, downloaded_files)]
"files": list(zip(files, downloaded_files))
+ [(None, dl_manager.iter_files(downloaded_dir)) for downloaded_dir in downloaded_dirs],
"metadata_files": metadata_files,
"split_name": split_name,
Expand Down
7 changes: 4 additions & 3 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2002,7 +2002,7 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
pa_type.list_size,
)
elif pa.types.is_list(pa_type):
offsets_arr = pa.array(range(len(array) + 1), pa.int32())
offsets_arr = pa.array(np.arange(len(array) + 1) * array.type.list_size, pa.int32())
if array.null_count > 0:
if config.PYARROW_VERSION.major < 10:
warnings.warn(
Expand Down Expand Up @@ -2061,6 +2061,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
array = array.storage
if hasattr(feature, "cast_storage"):
return feature.cast_storage(array)

elif pa.types.is_struct(array.type):
# feature must be a dict or Sequence(subfeatures_dict)
if isinstance(feature, Sequence) and isinstance(feature.feature, dict):
Expand Down Expand Up @@ -2126,7 +2127,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
if feature.length * len(array) == len(array_values):
return pa.FixedSizeListArray.from_arrays(_c(array_values, feature.feature), feature.length)
else:
offsets_arr = pa.array(range(len(array) + 1), pa.int32())
offsets_arr = pa.array(np.arange(len(array) + 1) * array.type.list_size, pa.int32())
if array.null_count > 0:
if config.PYARROW_VERSION.major < 10:
warnings.warn(
Expand Down Expand Up @@ -2233,7 +2234,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
if feature.length * len(array) == len(array_values):
return pa.FixedSizeListArray.from_arrays(_e(array_values, feature.feature), feature.length)
else:
offsets_arr = pa.array(range(len(array) + 1), pa.int32())
offsets_arr = pa.array(np.arange(len(array) + 1) * array.type.list_size, pa.int32())
if array.null_count > 0:
if config.PYARROW_VERSION.major < 10:
warnings.warn(
Expand Down
6 changes: 2 additions & 4 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_convert_to_arrow(batch_size, drop_last_batch):
num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
subtables = list(
_convert_to_arrow(
[(i, example) for i, example in enumerate(examples)],
list(enumerate(examples)),
batch_size=batch_size,
drop_last_batch=drop_last_batch,
)
Expand Down Expand Up @@ -162,9 +162,7 @@ def test_batch_arrow_tables(tables, batch_size, drop_last_batch):
num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size
num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
subtables = list(
_batch_arrow_tables(
[(i, table) for i, table in enumerate(tables)], batch_size=batch_size, drop_last_batch=drop_last_batch
)
_batch_arrow_tables(list(enumerate(tables)), batch_size=batch_size, drop_last_batch=drop_last_batch)
)
assert len(subtables) == num_batches
if drop_last_batch:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,18 @@ def test_cast_array_to_features_sequence_classlabel():
assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"])))


def test_cast_fixed_size_array_to_features_sequence():
    """Check that a fixed-size list array casts cleanly to a ``Sequence`` feature.

    Two targets are exercised: a ``Sequence`` declared with ``length=3``,
    expected to come back typed as ``pa.list_(pa.int64(), 3)``, and a
    ``Sequence`` with no length, expected to come back typed as
    ``pa.list_(pa.int64())``. In both cases the values must be unchanged.
    """
    source = pa.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], pa.list_(pa.int32(), 3))
    targets = (
        # Fixed size list
        (Sequence(Value("int64"), length=3), pa.list_(pa.int64(), 3)),
        # Variable size list
        (Sequence(Value("int64")), pa.list_(pa.int64())),
    )
    for feature, expected_type in targets:
        result = cast_array_to_feature(source, feature)
        assert result.type == expected_type
        # The cast may only change the type, never the contents.
        assert result.to_pylist() == source.to_pylist()


def test_cast_sliced_fixed_size_array_to_features():
arr = pa.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], pa.list_(pa.int32(), 3))
casted_array = cast_array_to_feature(arr[1:], Sequence(Value("int64"), length=3))
Expand Down

0 comments on commit 05fe5c0

Please sign in to comment.