Skip to content

Commit

Permalink
add allow_primitive_to_str and allow_decimal_to_str instead of allow_…
Browse files Browse the repository at this point in the history
…number_to_str (#6811)

* add allow_primitive_to_str and allow_decimal_to_str instead of allow_number_to_str

* add missing allow_decimal_str and split typerrors

* Style

---------

Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
  • Loading branch information
Modexus and mariosasko authored Apr 16, 2024
1 parent a188022 commit 8983a3b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 28 deletions.
10 changes: 7 additions & 3 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
# We use cast_array_to_feature to support casting to custom types like Audio and Image
# Also, when trying type "string", we don't want to convert integers or floats to "string".
# We only do it if trying_type is False - since this is what the user asks for.
out = cast_array_to_feature(out, type, allow_number_to_str=not self.trying_type)
out = cast_array_to_feature(
out, type, allow_primitive_to_str=not self.trying_type, allow_decimal_to_str=not self.trying_type
)
return out
except (
TypeError,
Expand Down Expand Up @@ -241,7 +243,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False)
)
if type is not None:
out = cast_array_to_feature(out, type, allow_number_to_str=True)
out = cast_array_to_feature(
out, type, allow_primitive_to_str=True, allow_decimal_to_str=True
)
return out
else:
raise
Expand All @@ -256,7 +260,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
elif trying_cast_to_python_objects and "Could not convert" in str(e):
out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True, optimize_list_casting=False))
if type is not None:
out = cast_array_to_feature(out, type, allow_number_to_str=True)
out = cast_array_to_feature(out, type, allow_primitive_to_str=True, allow_decimal_to_str=True)
return out
else:
raise
Expand Down
86 changes: 63 additions & 23 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,33 +1838,40 @@ def _storage_type(type: pa.DataType) -> pa.DataType:


@_wrap_for_chunked_arrays
def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
def array_cast(
array: pa.Array, pa_type: pa.DataType, allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
) -> Union[pa.Array, pa.FixedSizeListArray, pa.ListArray, pa.StructArray, pa.ExtensionArray]:
"""Improved version of `pa.Array.cast`
It supports casting `pa.StructArray` objects to re-order the fields.
It also let you control certain aspects of the casting, e.g. whether
to disable numbers (`floats` or `ints`) to strings.
to disable casting primitives (`booleans`, `floats` or `ints`) or
disable casting decimals to strings.
Args:
array (`pa.Array`):
PyArrow array to cast
pa_type (`pa.DataType`):
Target PyArrow type
allow_number_to_str (`bool`, defaults to `True`):
Whether to allow casting numbers to strings.
allow_primitive_to_str (`bool`, defaults to `True`):
Whether to allow casting primitives to strings.
Defaults to `True`.
allow_decimal_to_str (`bool`, defaults to `True`):
Whether to allow casting decimals to strings.
Defaults to `True`.
Raises:
`pa.ArrowInvalidError`: if the arrow data casting fails
`TypeError`: if the target type is not supported according, e.g.
- if a field is missing
- if casting from numbers to strings and `allow_number_to_str` is `False`
- if casting from primitives to strings and `allow_primitive_to_str` is `False`
- if casting from decimals to strings and `allow_decimal_to_str` is `False`
Returns:
`List[pyarrow.Array]`: the casted array
"""
_c = partial(array_cast, allow_number_to_str=allow_number_to_str)
_c = partial(array_cast, allow_primitive_to_str=allow_primitive_to_str, allow_decimal_to_str=allow_decimal_to_str)
if isinstance(array, pa.ExtensionArray):
array = array.storage
if isinstance(pa_type, pa.ExtensionType):
Expand Down Expand Up @@ -1933,22 +1940,27 @@ def array_cast(array: pa.Array, pa_type: pa.DataType, allow_number_to_str=True):
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
return pa.ListArray.from_arrays(array_offsets, _c(array.values, pa_type.value_type), mask=array.is_null())
else:
if (
not allow_number_to_str
and pa.types.is_string(pa_type)
and (pa.types.is_floating(array.type) or pa.types.is_integer(array.type))
):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} since allow_number_to_str is set to {allow_number_to_str}"
)
if pa.types.is_string(pa_type):
if not allow_primitive_to_str and pa.types.is_primitive(array.type):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} "
f"since allow_primitive_to_str is set to {allow_primitive_to_str} "
)
if not allow_decimal_to_str and pa.types.is_decimal(array.type):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} "
f"and allow_decimal_to_str is set to {allow_decimal_to_str}"
)
if pa.types.is_null(pa_type) and not pa.types.is_null(array.type):
raise TypeError(f"Couldn't cast array of type {array.type} to {pa_type}")
return array.cast(pa_type)
raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{pa_type}")


@_wrap_for_chunked_arrays
def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_to_str=True):
def cast_array_to_feature(
array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
) -> pa.Array:
"""Cast an array to the arrow type that corresponds to the requested feature type.
For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods
they defined to enable casting from other arrow types.
Expand All @@ -1958,23 +1970,31 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
The PyArrow array to cast.
feature (`datasets.features.FeatureType`):
The target feature type.
allow_number_to_str (`bool`, defaults to `True`):
Whether to allow casting numbers to strings.
allow_primitive_to_str (`bool`, defaults to `True`):
Whether to allow casting primitives to strings.
Defaults to `True`.
allow_decimal_to_str (`bool`, defaults to `True`):
Whether to allow casting decimals to strings.
Defaults to `True`.
Raises:
`pa.ArrowInvalidError`: if the arrow data casting fails
`TypeError`: if the target type is not supported according, e.g.
- if a field is missing
- if casting from numbers to strings and `allow_number_to_str` is `False`
- if casting from primitives and `allow_primitive_to_str` is `False`
- if casting from decimals and `allow_decimal_to_str` is `False`
Returns:
array (`pyarrow.Array`): the casted array
"""
from .features.features import Sequence, get_nested_type

_c = partial(cast_array_to_feature, allow_number_to_str=allow_number_to_str)
_c = partial(
cast_array_to_feature,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)

if isinstance(array, pa.ExtensionArray):
array = array.storage
Expand Down Expand Up @@ -2011,9 +2031,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
storage_type = _storage_type(array_type)
if array_type != storage_type:
# Temporarily convert to the storage type to support extension types in the slice operation
array = array_cast(array, storage_type, allow_number_to_str=allow_number_to_str)
array = array_cast(
array,
storage_type,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True)
array = array_cast(array, array_type, allow_number_to_str=allow_number_to_str)
array = array_cast(
array,
array_type,
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
else:
array = pc.list_slice(array, 0, feature.length, return_fixed_size_list=True)
array_values = array.values
Expand Down Expand Up @@ -2069,9 +2099,19 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature.feature), mask=array.is_null())
if pa.types.is_null(array.type):
return array_cast(array, get_nested_type(feature), allow_number_to_str=allow_number_to_str)
return array_cast(
array,
get_nested_type(feature),
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
elif not isinstance(feature, (Sequence, dict, list, tuple)):
return array_cast(array, feature(), allow_number_to_str=allow_number_to_str)
return array_cast(
array,
feature(),
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{feature}")


Expand Down
38 changes: 36 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import pickle
from decimal import Decimal
from typing import List, Union

import numpy as np
Expand Down Expand Up @@ -1098,11 +1099,44 @@ def test_indexed_table_mixin():
assert table.fast_slice(2, 13) == pa_table.slice(2, 13)


def test_cast_array_to_features():
def test_cast_integer_array_to_features():
arr = pa.array([[0, 1]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_float_array_to_features():
arr = pa.array([[0.0, 1.0]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_boolean_array_to_features():
arr = pa.array([[False, True]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False)


def test_cast_decimal_array_to_features():
arr = pa.array([[Decimal(0), Decimal(1)]])
assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
assert cast_array_to_feature(arr, Sequence(Value("string")), allow_primitive_to_str=False).type == pa.list_(
pa.string()
)
with pytest.raises(TypeError):
cast_array_to_feature(arr, Sequence(Value("string")), allow_number_to_str=False)
cast_array_to_feature(arr, Sequence(Value("string")), allow_decimal_to_str=False)


def test_cast_array_to_features_nested():
Expand Down

0 comments on commit 8983a3b

Please sign in to comment.