diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 9570fb1808c..2dae3a52fd3 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -187,14 +187,20 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray: else: zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array) array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist() + if len(array) > 0: if any( (isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape)) or (isinstance(x, float) and np.isnan(x)) for x in array ): + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array, dtype=object) return np.array(array, copy=False, dtype=object) - return np.array(array, copy=False) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + return np.asarray(array) + else: + return np.array(array, copy=False) class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):