Skip to content

Commit

Permalink
restore arrow array to numpy to numpy extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hh committed Oct 15, 2024
1 parent ac5a46d commit a89ef52
Showing 1 changed file with 35 additions and 36 deletions.
71 changes: 35 additions & 36 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,41 +107,6 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool:
return pa_array.null_count > 0


def _arrow_array_to_numpy(pa_array: pa.Array) -> np.ndarray:
if isinstance(pa_array, pa.ChunkedArray):
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
else:
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only)
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()

if len(array) > 0:
if any(
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
or (isinstance(x, float) and np.isnan(x))
for x in array
):
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array, dtype=object)
return np.array(array, copy=False, dtype=object)
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array)
else:
return np.array(array, copy=False)


def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]:
# convert to list of dicts
list_of_dicts = []
Expand Down Expand Up @@ -231,7 +196,41 @@ def extract_column(self, pa_table: pa.Table) -> np.ndarray:
return self._arrow_array_to_numpy(pa_table[pa_table.column_names[0]])

def extract_batch(self, pa_table: pa.Table) -> dict:
return {col: _arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}

def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
if isinstance(pa_array, pa.ChunkedArray):
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and all(
not _is_array_with_nulls(chunk) for chunk in pa_array.chunks
)
array: List = [row for chunk in pa_array.chunks for row in chunk.to_numpy(zero_copy_only=zero_copy_only)]
else:
if isinstance(pa_array.type, _ArrayXDExtensionType):
# don't call to_pylist() to preserve dtype of the fixed-size array
zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True)
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only)
else:
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()

if len(array) > 0:
if any(
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
or (isinstance(x, float) and np.isnan(x))
for x in array
):
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array, dtype=object)
return np.array(array, copy=False, dtype=object)
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
return np.asarray(array)
else:
return np.array(array, copy=False)


class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):
Expand Down

0 comments on commit a89ef52

Please sign in to comment.