From f2422e906b8847d1ce4d63fa42e79d879dbe30a0 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 19 Apr 2023 12:18:53 +0200
Subject: [PATCH] FIX-#4828: allow `dict_apply_builder` use keyword argument `internal_indices` (#5945)

Signed-off-by: Anatoly Myachev
---
 .../storage_formats/pandas/query_compiler.py | 17 ++++--
 modin/pandas/test/dataframe/test_udf.py      | 54 ++++++++++++++++---
 2 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py
index 4eb622d280e..0662a56311f 100644
--- a/modin/core/storage_formats/pandas/query_compiler.py
+++ b/modin/core/storage_formats/pandas/query_compiler.py
@@ -2700,15 +2700,24 @@ def _dict_func(self, func, axis, *args, **kwargs):
         if "axis" not in kwargs:
             kwargs["axis"] = axis
 
-        def dict_apply_builder(df, func_dict={}):  # pragma: no cover
+        func = {k: wrap_udf_function(v) if callable(v) else v for k, v in func.items()}
+
+        def dict_apply_builder(df, internal_indices=[]):  # pragma: no cover
             # Sometimes `apply` can return a `Series`, but we require that internally
             # all objects are `DataFrame`s.
-            return pandas.DataFrame(df.apply(func_dict, *args, **kwargs))
+            # It looks like it doesn't need to use `internal_indices` option internally
+            # for the case since `apply` use labels from dictionary keys in `func` variable.
+            return pandas.DataFrame(df.apply(func, *args, **kwargs))
 
-        func = {k: wrap_udf_function(v) if callable(v) else v for k, v in func.items()}
+        labels = list(func.keys())
         return self.__constructor__(
             self._modin_frame.apply_full_axis_select_indices(
-                axis, dict_apply_builder, func, keep_remaining=False
+                axis,
+                dict_apply_builder,
+                labels,
+                new_index=labels if axis == 1 else None,
+                new_columns=labels if axis == 0 else None,
+                keep_remaining=False,
             )
         )
 
diff --git a/modin/pandas/test/dataframe/test_udf.py b/modin/pandas/test/dataframe/test_udf.py
index a7330afaf20..566037a189c 100644
--- a/modin/pandas/test/dataframe/test_udf.py
+++ b/modin/pandas/test/dataframe/test_udf.py
@@ -40,7 +40,7 @@
     arg_keys,
     default_to_pandas_ignore_string,
 )
-from modin.config import NPartitions, StorageFormat
+from modin.config import NPartitions
 from modin.test.test_utils import warns_that_defaulting_to_pandas
 from modin.utils import get_current_execution
 
@@ -116,13 +116,6 @@ def test_aggregate_error_checking():
         modin_df.aggregate("NOT_EXISTS")
 
 
-@pytest.mark.xfail(
-    StorageFormat.get() == "Pandas",
-    reason="DataFrame.apply(dict) raises an exception because of a bug in its"
-    + "implementation for pandas storage format, this prevents us from catching the desired"
-    + "exception. You can track this bug at:"
-    + "https://github.com/modin-project/modin/issues/3221",
-)
 @pytest.mark.parametrize(
     "func",
     agg_func_values + agg_func_except_values,
@@ -245,6 +238,51 @@ def test_apply_udf(data, func):
     )
 
 
+def test_apply_dict_4828():
+    data = [[2, 4], [1, 3]]
+    modin_df1, pandas_df1 = create_test_dfs(data)
+    eval_general(
+        modin_df1,
+        pandas_df1,
+        lambda df: df.apply({0: (lambda x: x**2)}),
+    )
+    eval_general(
+        modin_df1,
+        pandas_df1,
+        lambda df: df.apply({0: (lambda x: x**2)}, axis=1),
+    )
+
+    # several partitions along axis 0
+    modin_df2, pandas_df2 = create_test_dfs(data, index=[2, 3])
+    modin_df3 = pd.concat([modin_df1, modin_df2], axis=0)
+    pandas_df3 = pandas.concat([pandas_df1, pandas_df2], axis=0)
+    eval_general(
+        modin_df3,
+        pandas_df3,
+        lambda df: df.apply({0: (lambda x: x**2)}),
+    )
+    eval_general(
+        modin_df3,
+        pandas_df3,
+        lambda df: df.apply({0: (lambda x: x**2)}, axis=1),
+    )
+
+    # several partitions along axis 1
+    modin_df4, pandas_df4 = create_test_dfs(data, columns=[2, 3])
+    modin_df5 = pd.concat([modin_df1, modin_df4], axis=1)
+    pandas_df5 = pandas.concat([pandas_df1, pandas_df4], axis=1)
+    eval_general(
+        modin_df5,
+        pandas_df5,
+        lambda df: df.apply({0: (lambda x: x**2)}),
+    )
+    eval_general(
+        modin_df5,
+        pandas_df5,
+        lambda df: df.apply({0: (lambda x: x**2)}, axis=1),
+    )
+
+
 def test_apply_modin_func_4635():
     data = [1]
     modin_df, pandas_df = create_test_dfs(data)