Skip to content

Commit

Permalink
PERF-modin-project#5296: Partition parquet file if it has too few row…
Browse files Browse the repository at this point in the history
… groups (modin-project#7016)

Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev authored Mar 8, 2024
1 parent 14452a8 commit b1501d8
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 5 deletions.
110 changes: 109 additions & 1 deletion modin/core/io/column_stores/parquet_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from packaging import version
from pandas.io.common import stringify_path

from modin.config import NPartitions
from modin.config import MinPartitionSize, NPartitions
from modin.core.io.column_stores.column_store_dispatcher import ColumnStoreDispatcher
from modin.error_message import ErrorMessage
from modin.utils import _inherit_docstrings
Expand Down Expand Up @@ -646,6 +646,110 @@ def build_index(cls, dataset, partition_ids, index_columns, filters):
complete_index = index_objs[0].append(index_objs[1:])
return complete_index, range_index or (len(index_columns) == 0)

@classmethod
def _normalize_partitioning(cls, remote_parts, row_lengths, column_widths):
"""
Normalize partitioning according to the default partitioning scheme in Modin.
The result of 'read_parquet()' is often under partitioned over rows and over partitioned
over columns, so this method expands the number of row splits and shrink the number of column splits.
Parameters
----------
remote_parts : np.ndarray
row_lengths : list of ints or None
Row lengths, if 'None', won't repartition across rows.
column_widths : list of ints
Returns
-------
remote_parts : np.ndarray
row_lengths : list of ints or None
column_widths : list of ints
"""
if len(remote_parts) == 0:
return remote_parts, row_lengths, column_widths

from modin.core.storage_formats.pandas.utils import get_length_list

# The code in this function is actually a duplication of what 'BaseQueryCompiler.repartition()' does,
# however this implementation works much faster for some reason

actual_row_nparts = remote_parts.shape[0]

if row_lengths is not None:
desired_row_nparts = max(
1, min(sum(row_lengths) // MinPartitionSize.get(), NPartitions.get())
)
else:
desired_row_nparts = actual_row_nparts

# only repartition along rows if the actual number of row splits 1.5 times SMALLER than desired
if 1.5 * actual_row_nparts < desired_row_nparts:
# assuming that the sizes of parquet's row groups are more or less equal,
# so trying to use the same number of splits for each partition
splits_per_partition = desired_row_nparts // actual_row_nparts
remainder = desired_row_nparts % actual_row_nparts

new_parts = []
new_row_lengths = []

for row_idx, (part_len, row_parts) in enumerate(
zip(row_lengths, remote_parts)
):
num_splits = splits_per_partition
# 'remainder' indicates how many partitions have to be split into 'num_splits + 1' splits
# to have exactly 'desired_row_nparts' in the end
if row_idx < remainder:
num_splits += 1

if num_splits == 1:
new_parts.append(row_parts)
new_row_lengths.append(part_len)
continue

offset = len(new_parts)
# adding empty row parts according to the number of splits
new_parts.extend([[] for _ in range(num_splits)])
for part in row_parts:
split = cls.frame_cls._partition_mgr_cls._column_partitions_class(
[part]
).apply(
lambda df: df,
num_splits=num_splits,
maintain_partitioning=False,
)
for i in range(num_splits):
new_parts[offset + i].append(split[i])

new_row_lengths.extend(get_length_list(part_len, num_splits))

remote_parts = np.array(new_parts)
row_lengths = new_row_lengths

desired_col_nparts = max(
1, min(sum(column_widths) // MinPartitionSize.get(), NPartitions.get())
)
# only repartition along cols if the actual number of col splits 1.5 times BIGGER than desired
if 1.5 * desired_col_nparts < remote_parts.shape[1]:
remote_parts = np.array(
[
(
cls.frame_cls._partition_mgr_cls._row_partition_class(
row_parts
).apply(
lambda df: df,
num_splits=desired_col_nparts,
maintain_partitioning=False,
)
)
for row_parts in remote_parts
]
)
column_widths = get_length_list(sum(column_widths), desired_col_nparts)

return remote_parts, row_lengths, column_widths

@classmethod
def build_query_compiler(cls, dataset, columns, index_columns, **kwargs):
"""
Expand Down Expand Up @@ -687,6 +791,10 @@ def build_query_compiler(cls, dataset, columns, index_columns, **kwargs):
else:
row_lengths = None

remote_parts, row_lengths, column_widths = cls._normalize_partitioning(
remote_parts, row_lengths, column_widths
)

if (
dataset.pandas_metadata
and "column_indexes" in dataset.pandas_metadata
Expand Down
16 changes: 12 additions & 4 deletions modin/core/storage_formats/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def compute_chunksize(axis_len, num_splits, min_block_size=None):
return max(chunksize, min_block_size)


def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None):
def split_result_of_axis_func_pandas(
axis, num_splits, result, length_list=None, min_block_size=None
):
"""
Split pandas DataFrame evenly based on the provided number of splits.
Expand All @@ -73,6 +75,9 @@ def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None)
length_list : list of ints, optional
List of slice lengths to split DataFrame into. This is used to
return the DataFrame to its original partitioning schema.
min_block_size : int, optional
Minimum number of rows/columns in a single split.
If not specified, the value is assumed equal to ``MinPartitionSize``.
Returns
-------
Expand All @@ -83,7 +88,7 @@ def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None)
return [result]

if length_list is None:
length_list = get_length_list(result.shape[axis], num_splits)
length_list = get_length_list(result.shape[axis], num_splits, min_block_size)
# Inserting the first "zero" to properly compute cumsum indexing slices
length_list = np.insert(length_list, obj=0, values=[0])

Expand All @@ -109,7 +114,7 @@ def split_result_of_axis_func_pandas(axis, num_splits, result, length_list=None)
]


def get_length_list(axis_len: int, num_splits: int) -> list:
def get_length_list(axis_len: int, num_splits: int, min_block_size=None) -> list:
"""
Compute partitions lengths along the axis with the specified number of splits.
Expand All @@ -119,13 +124,16 @@ def get_length_list(axis_len: int, num_splits: int) -> list:
Element count in an axis.
num_splits : int
Number of splits along the axis.
min_block_size : int, optional
Minimum number of rows/columns in a single split.
If not specified, the value is assumed equal to ``MinPartitionSize``.
Returns
-------
list of ints
List of integer lengths of partitions.
"""
chunksize = compute_chunksize(axis_len, num_splits)
chunksize = compute_chunksize(axis_len, num_splits, min_block_size)
return [
(
chunksize
Expand Down
68 changes: 68 additions & 0 deletions modin/pandas/test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TestReadFromPostgres,
TestReadFromSqlServer,
)
from modin.core.storage_formats.pandas.utils import split_result_of_axis_func_pandas
from modin.db_conn import ModinDatabaseConnection, UnsupportedDatabaseException
from modin.pandas.io import from_arrow, from_ray_dataset, to_pandas
from modin.test.test_utils import warns_that_defaulting_to_pandas
Expand Down Expand Up @@ -2053,6 +2054,73 @@ def test_read_parquet_s3_with_column_partitioning(
storage_options=s3_storage_options,
)

@pytest.mark.skipif(
StorageFormat.get() != "Pandas",
reason="Doesn't make sense for executions that do not use Modin's partitioning",
)
@pytest.mark.skipif(
Engine.get() == "Python",
reason="Python engine uses default-to-pandas implementations and has another logic of partitioning",
)
@pytest.mark.parametrize(
"modify_config", [{NPartitions: 16, MinPartitionSize: 32}], indirect=True
)
@pytest.mark.parametrize("parquet_num_row_groups", [1, 2, 7, 8, 11, 16, 23, 32])
@pytest.mark.parametrize("ncols", [2, 10, 34])
def test_read_parquet_proper_partitioning(
self, modify_config, ncols, parquet_num_row_groups, tmp_path, engine
):
"""
Test that no matter how the original parquet file is partitioned,
the resulted modin dataframe has proper partitioning.
"""
nrows = 1024
test_df = pandas.DataFrame(
{
**{f"data_col{i}": np.arange(nrows) for i in range(ncols)},
}
)
path = tmp_path / "data"
path.mkdir()
parts = split_result_of_axis_func_pandas(
axis=0, num_splits=parquet_num_row_groups, result=test_df, min_block_size=1
)
for i, part in enumerate(parts):
part.to_parquet(
path / f"parquet_part{i}.parquet",
engine=engine,
)

md_df = pd.read_parquet(path)

expected_num_rows = max(
1, min(nrows // MinPartitionSize.get(), NPartitions.get())
)
expected_num_rows = (
expected_num_rows
if parquet_num_row_groups * 1.5 < expected_num_rows
else parquet_num_row_groups
)

expected_num_cols = max(
1, min(ncols // MinPartitionSize.get(), NPartitions.get())
)
expected_num_cols = (
# the repartition logic EXPANDS the number of row splits and SHRINKS the number
# of col splits, that's why we're applying '1.5' multiplier to different variables,
# (apply multiplier to 'expected_*' for cols and to 'actual_*' for rows)
expected_num_cols
if expected_num_cols * 1.5 < ncols
else ncols
)

assert md_df._query_compiler._modin_frame._partitions.shape[0] == min(
expected_num_rows, NPartitions.get()
)
assert md_df._query_compiler._modin_frame._partitions.shape[1] == min(
expected_num_cols, NPartitions.get()
)


# Leave this test apart from the test classes, which skip the default to pandas
# warning check. We want to make sure we are NOT defaulting to pandas for a
Expand Down

0 comments on commit b1501d8

Please sign in to comment.