diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index 3ec78100abe97..e5db0bd701f1d 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -198,25 +198,18 @@ def binary_search_for_buckets(value):
             idx = bisect.bisect(bins, value) - 1
             return float(idx)
 
-        output_df = None
-        for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
-            # sdf.na.drop to match handleInvalid="skip" in Bucketizer
-
-            bucket_df = sdf.na.drop(subset=[colname]).withColumn(
-                bucket_name,
-                binary_search_for_buckets(F.col(colname).cast("double")),
+        output_df = (
+            sdf.select(
+                F.posexplode(
+                    F.array([F.col(colname).cast("double") for colname in colnames])
+                ).alias("__group_id", "__value")
             )
-
-            if output_df is None:
-                output_df = bucket_df.select(
-                    F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
-                )
-            else:
-                output_df = output_df.union(
-                    bucket_df.select(
-                        F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
-                    )
-                )
+            # to match handleInvalid="skip" in Bucketizer
+            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+                F.col("__group_id"),
+                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+            )
+        )
 
         # 2. Calculate the count based on each group and bucket.
         # +----------+-------+------+
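
For context, here is a minimal standalone sketch (not part of the patch) of the pattern the new code relies on: `F.posexplode` over `F.array` flattens several columns into `(position, value)` pairs in one pass, replacing the old per-column loop plus repeated `union`. The toy DataFrame and its column names are invented for illustration; the filter uses `F.isnan` rather than the `Column.isNaN` method in the patch, but both express the same check.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Two numeric columns; the None simulates a value Bucketizer would skip.
df = spark.createDataFrame([(1.0, 10.0), (None, 20.0)], ["a", "b"])

# posexplode over an array of the casted columns emits one
# (column position, value) pair per input row and per column,
# so no per-column loop or union is needed.
exploded = df.select(
    F.posexplode(
        F.array([F.col(c).cast("double") for c in df.columns])
    ).alias("__group_id", "__value")
)

# NULL/NaN entries survive the explode, so they are filtered
# explicitly, mirroring handleInvalid="skip" in Bucketizer.
exploded.where(
    F.col("__value").isNotNull() & ~F.isnan(F.col("__value"))
).show()
# +----------+-------+
# |__group_id|__value|
# +----------+-------+
# |         0|    1.0|
# |         1|   10.0|
# |         1|   20.0|
# +----------+-------+
```

Besides being shorter, the single `select` keeps the job as one scan of the source data, whereas the old loop built a `union` of one filtered branch per column.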