From 0c18fc072b05671bc9c74a43de49b563a1ef7907 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sat, 24 Aug 2024 16:34:48 +0800
Subject: [PATCH] [SPARK-49365][PS] Simplify the bucket aggregation in hist
 plot

### What changes were proposed in this pull request?
Simplify the bucket aggregation in hist plot.

### Why are the changes needed?
To simplify the implementation by eliminating the union of multiple DataFrames.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual checks

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47852 from zhengruifeng/plot_parallel_hist.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/pandas/plot/core.py | 29 +++++++++++------------------
 1 file changed, 11 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index 3ec78100abe97..e5db0bd701f1d 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -198,25 +198,18 @@ def binary_search_for_buckets(value):
             idx = bisect.bisect(bins, value) - 1
             return float(idx)
 
-        output_df = None
-        for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
-            # sdf.na.drop to match handleInvalid="skip" in Bucketizer
-
-            bucket_df = sdf.na.drop(subset=[colname]).withColumn(
-                bucket_name,
-                binary_search_for_buckets(F.col(colname).cast("double")),
+        output_df = (
+            sdf.select(
+                F.posexplode(
+                    F.array([F.col(colname).cast("double") for colname in colnames])
+                ).alias("__group_id", "__value")
             )
-
-            if output_df is None:
-                output_df = bucket_df.select(
-                    F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
-                )
-            else:
-                output_df = output_df.union(
-                    bucket_df.select(
-                        F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
-                    )
-                )
+            # to match handleInvalid="skip" in Bucketizer
+            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+                F.col("__group_id"),
+                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+            )
+        )
 
         # 2. Calculate the count based on each group and bucket.
         #  +----------+-------+------+
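
A minimal standalone sketch of the technique in the hunk above, for illustration only: it assumes a local SparkSession, and `colnames`, `bins`, and the sample data are hypothetical stand-ins for what `compute_hist` receives. The idea is the one the patch adopts: instead of bucketizing each column separately and unioning N per-column DataFrames, `posexplode` flattens all columns into (column-index, value) pairs, so a single filter and projection feed one downstream aggregation.

```python
import bisect

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# Hypothetical input: two numeric columns with a null and a NaN mixed in.
sdf = spark.createDataFrame(
    [(1.0, 10.0), (2.5, None), (float("nan"), 30.0)],
    ["a", "b"],
)
colnames = ["a", "b"]
bins = [0.0, 2.0, 4.0, 40.0]  # sorted bucket boundaries, as in compute_hist

@F.udf(returnType="double")
def binary_search_for_buckets(value):
    # Same lookup as the patch: index of the bin containing the value.
    idx = bisect.bisect(bins, value) - 1
    return float(idx)

output_df = (
    sdf.select(
        # One (__group_id, __value) row per column per input row;
        # __group_id is the column's position in `colnames`.
        F.posexplode(
            F.array([F.col(colname).cast("double") for colname in colnames])
        ).alias("__group_id", "__value")
    )
    # Drop nulls and NaNs, matching handleInvalid="skip" in Bucketizer.
    # Column.isNaN is used here as in the patch (Spark master / 4.0).
    .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
    .select(
        F.col("__group_id"),
        binary_search_for_buckets(F.col("__value")).alias("__bucket"),
    )
)

# Step 2 of compute_hist then counts rows per (group, bucket).
output_df.groupBy("__group_id", "__bucket").count().show()
```

Because everything stays in one DataFrame, Spark plans a single scan and a single aggregation instead of one branch per column stitched together with `union`, which is what makes the new version simpler (and typically cheaper) than the loop it replaces.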