diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py index 2de35a943e7..154f4d61072 100644 --- a/src/datasets/utils/tf_utils.py +++ b/src/datasets/utils/tf_utils.py @@ -278,7 +278,7 @@ def __init__( self.cols_to_retain = cols_to_retain self.collate_fn = collate_fn self.collate_fn_args = collate_fn_args - self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype in (np.unicode_, np.str_)] + self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype is np.str_] # Strings will be converted to arrays of single unicode chars, so that we can have a constant itemsize self.columns_to_np_types = { col: dtype if col not in self.string_columns else np.dtype("U1")