[SPARK-48901][SPARK-48916][SS][PYTHON] Introduce clusterBy DataStreamWriter API

### What changes were proposed in this pull request?
Introduce a new clusterBy DataStreamWriter API in Scala, Spark Connect, and PySpark.

### Why are the changes needed?
Provides another way for users to create clustered tables for streaming writes.

### Does this PR introduce _any_ user-facing change?
Yes, it adds a new clusterBy DataStreamWriter API in Scala, Spark Connect, and PySpark to allow specifying the clustering columns when writing streaming DataFrames.
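
For illustration, a minimal usage sketch in Scala (the source settings, checkpoint path, and table name below are placeholders, not taken from this patch):

```scala
// Illustrative sketch only; names and paths are placeholders.
val df = spark.readStream
  .format("rate")
  .load()

val query = df.writeStream
  .clusterBy("value")                         // new API: cluster the output table by `value`
  .option("checkpointLocation", "/tmp/ckpt")  // placeholder checkpoint location
  .toTable("events_clustered")                // placeholder table name
```

The same builder method is mirrored on the Spark Connect Scala client and on PySpark's `DataStreamWriter`.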

### How was this patch tested?
See new unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47376 from chirag-s-db/cluster-by-stream.

Lead-authored-by: Chirag Singh <chirag.singh@databricks.com>
Co-authored-by: Chirag Singh <137233133+chirag-s-db@users.noreply.github.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
2 people authored and HeartSaVioR committed Jul 29, 2024
1 parent 112a52d commit efc6a75
Showing 13 changed files with 398 additions and 74 deletions.
3 changes: 3 additions & 0 deletions connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -240,6 +240,9 @@ message WriteStreamOperationStart {

  StreamingForeachFunction foreach_writer = 13;
  StreamingForeachFunction foreach_batch = 14;

  // (Optional) Columns used for clustering the table.
  repeated string clustering_column_names = 15;
}

message StreamingForeachFunction {
@@ -3139,6 +3139,10 @@ class SparkConnectPlanner(
      writer.partitionBy(writeOp.getPartitioningColumnNamesList.asScala.toList: _*)
    }

    if (writeOp.getClusteringColumnNamesCount > 0) {
      writer.clusterBy(writeOp.getClusteringColumnNamesList.asScala.toList: _*)
    }

    writeOp.getTriggerCase match {
      case TriggerCase.PROCESSING_TIME_INTERVAL =>
        writer.trigger(Trigger.ProcessingTime(writeOp.getProcessingTimeInterval))
@@ -159,6 +159,23 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
    this
  }

  /**
   * Clusters the output by the given columns. If specified, the output is laid out such that
   * records with similar values on the clustering column are grouped together in the same file.
   *
   * Clustering improves query efficiency by allowing queries with predicates on the clustering
   * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high
   * cardinality columns.
   *
   * @since 4.0.0
   */
  @scala.annotation.varargs
  def clusterBy(colNames: String*): DataStreamWriter[T] = {
    sinkBuilder.clearClusteringColumnNames()
    sinkBuilder.addAllClusteringColumnNames(colNames.asJava)
    this
  }

  /**
   * Adds an output option for the underlying data source.
   *
@@ -268,6 +268,42 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
    }
  }

  test("clusterBy") {
    withSQLConf(
      "spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
    ) {
      spark.sql("DROP TABLE IF EXISTS my_table").collect()

      withTempPath { ckpt =>
        val q1 = spark.readStream
          .format("rate")
          .load()
          .writeStream
          .clusterBy("value")
          .option("checkpointLocation", ckpt.getCanonicalPath)
          .toTable("my_table")

        try {
          q1.processAllAvailable()
          eventually(timeout(30.seconds)) {
            checkAnswer(
              spark.sql("DESCRIBE my_table"),
              Seq(
                Row("timestamp", "timestamp", null),
                Row("value", "bigint", null),
                Row("# Clustering Information", "", ""),
                Row("# col_name", "data_type", "comment"),
                Row("value", "bigint", null)))
            assert(spark.table("my_table").count() > 0)
          }
        } finally {
          q1.stop()
          spark.sql("DROP TABLE my_table")
        }
      }
    }
  }

  test("throw exception in streaming") {
    try {
      val session = spark
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -102,7 +102,9 @@ object MimaExcludes {
    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.session"),
    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext"),
    // SPARK-48761: Add clusterBy() to CreateTableWriter.
    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy")
    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy"),
    // SPARK-48901: Add clusterBy() to DataStreamWriter.
    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy")
  )

  // Default exclude rules
136 changes: 68 additions & 68 deletions python/pyspark/sql/connect/proto/commands_pb2.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -905,6 +905,7 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
    TABLE_NAME_FIELD_NUMBER: builtins.int
    FOREACH_WRITER_FIELD_NUMBER: builtins.int
    FOREACH_BATCH_FIELD_NUMBER: builtins.int
    CLUSTERING_COLUMN_NAMES_FIELD_NUMBER: builtins.int
    @property
    def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
        """(Required) The output of the `input` streaming relation will be written."""
@@ -932,6 +933,11 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
    def foreach_writer(self) -> global___StreamingForeachFunction: ...
    @property
    def foreach_batch(self) -> global___StreamingForeachFunction: ...
    @property
    def clustering_column_names(
        self,
    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
        """(Optional) Columns used for clustering the table."""
    def __init__(
        self,
        *,
@@ -949,6 +955,7 @@
        table_name: builtins.str = ...,
        foreach_writer: global___StreamingForeachFunction | None = ...,
        foreach_batch: global___StreamingForeachFunction | None = ...,
        clustering_column_names: collections.abc.Iterable[builtins.str] | None = ...,
    ) -> None: ...
    def HasField(
        self,
@@ -982,6 +989,8 @@
        field_name: typing_extensions.Literal[
            "available_now",
            b"available_now",
            "clustering_column_names",
            b"clustering_column_names",
            "continuous_checkpoint_interval",
            b"continuous_checkpoint_interval",
            "foreach_batch",
19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -445,6 +445,25 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]

    partitionBy.__doc__ = PySparkDataStreamWriter.partitionBy.__doc__

    @overload
    def clusterBy(self, *cols: str) -> "DataStreamWriter":
        ...

    @overload
    def clusterBy(self, __cols: List[str]) -> "DataStreamWriter":
        ...

    def clusterBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
            cols = cols[0]
        # Clear any existing columns (if any).
        while len(self._write_proto.clustering_column_names) > 0:
            self._write_proto.clustering_column_names.pop()
        self._write_proto.clustering_column_names.extend(cast(List[str], cols))
        return self

    clusterBy.__doc__ = PySparkDataStreamWriter.clusterBy.__doc__

    def queryName(self, queryName: str) -> "DataStreamWriter":
        if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
            raise PySparkValueError(
59 changes: 59 additions & 0 deletions python/pyspark/sql/streaming/readwriter.py
@@ -1123,6 +1123,65 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]
        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
        return self

    @overload
    def clusterBy(self, *cols: str) -> "DataStreamWriter":
        ...

    @overload
    def clusterBy(self, __cols: List[str]) -> "DataStreamWriter":
        ...

    def clusterBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
        """Clusters the output by the given columns.

        If specified, the output is laid out such that records with similar values on the
        clustering column(s) are grouped together in the same file.

        Clustering improves query efficiency by allowing queries with predicates on the
        clustering columns to skip unnecessary data. Unlike partitioning, clustering can be
        used on very high cardinality columns.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        cols : str or list
            name of columns

        Notes
        -----
        This API is evolving.

        Examples
        --------
        >>> df = spark.readStream.format("rate").load()
        >>> df.writeStream.clusterBy("value")
        <...streaming.readwriter.DataStreamWriter object ...>

        Cluster-by timestamp column from Rate source.

        >>> import tempfile
        >>> import time
        >>> with tempfile.TemporaryDirectory(prefix="partitionBy1") as d:
        ...     with tempfile.TemporaryDirectory(prefix="partitionBy2") as cp:
        ...         df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
        ...         q = df.writeStream.clusterBy(
        ...             "timestamp").format("parquet").option("checkpointLocation", cp).start(d)
        ...         time.sleep(5)
        ...         q.stop()
        ...         spark.read.schema(df.schema).parquet(d).show()
        +...---------+-----+
        |...timestamp|value|
        +...---------+-----+
        ...
        """
        from pyspark.sql.classic.column import _to_seq

        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
            cols = cols[0]
        self._jwrite = self._jwrite.clusterBy(_to_seq(self._spark._sc, cols))
        return self

    def queryName(self, queryName: str) -> "DataStreamWriter":
        """Specifies the name of the :class:`StreamingQuery` that can be started with
        :func:`start`. This name must be unique among all the currently active queries
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming.py
@@ -429,6 +429,31 @@ def test_streaming_write_to_table(self):
        result = self.spark.sql("SELECT value FROM output_table").collect()
        self.assertTrue(len(result) > 0)

    def test_streaming_write_to_table_cluster_by(self):
        with self.table("output_table"), tempfile.TemporaryDirectory(prefix="to_table") as tmpdir:
            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
            q = df.writeStream.clusterBy("value").toTable(
                "output_table", format="parquet", checkpointLocation=tmpdir
            )
            self.assertTrue(q.isActive)
            time.sleep(10)
            q.stop()
            result = self.spark.sql("DESCRIBE output_table").collect()
            self.assertEqual(
                set(
                    [
                        Row(col_name="timestamp", data_type="timestamp", comment=None),
                        Row(col_name="value", data_type="bigint", comment=None),
                        Row(col_name="# Clustering Information", data_type="", comment=""),
                        Row(col_name="# col_name", data_type="data_type", comment="comment"),
                        Row(col_name="value", data_type="bigint", comment=None),
                    ]
                ),
                set(result),
            )
            result = self.spark.sql("SELECT value FROM output_table").collect()
            self.assertTrue(len(result) > 0)

    def test_streaming_with_temporary_view(self):
        """
        This verifies createOrReplaceTempView() works with a streaming dataframe. An SQL
@@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.streaming._
@@ -166,6 +167,24 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
  @scala.annotation.varargs
  def partitionBy(colNames: String*): DataStreamWriter[T] = {
    this.partitioningColumns = Option(colNames)
    validatePartitioningAndClustering()
    this
  }

  /**
   * Clusters the output by the given columns. If specified, the output is laid out such that
   * records with similar values on the clustering column are grouped together in the same file.
   *
   * Clustering improves query efficiency by allowing queries with predicates on the clustering
   * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high
   * cardinality columns.
   *
   * @since 4.0.0
   */
  @scala.annotation.varargs
  def clusterBy(colNames: String*): DataStreamWriter[T] = {
    this.clusteringColumns = Option(colNames)
    validatePartitioningAndClustering()
    this
  }

@@ -288,12 +307,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

    if (!catalog.asTableCatalog.tableExists(identifier)) {
      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

      val properties = normalizedClusteringCols.map { cols =>
        Map(
          DataSourceUtils.CLUSTERING_COLUMNS_KEY -> DataSourceUtils.encodePartitioningColumns(cols))
      }.getOrElse(Map.empty)
      val partitioningOrClusteringTransform = normalizedClusteringCols.map { colNames =>
        Array(ClusterByTransform(colNames.map(col => FieldReference(col)))).toImmutableArraySeq
      }.getOrElse(partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq)

      /**
       * Note, currently the new table creation by this API doesn't fully cover the V2 table.
       * TODO (SPARK-33638): Full support of v2 table creation
       */
      val tableSpec = UnresolvedTableSpec(
        Map.empty[String, String],
        properties,
        Some(source),
        OptionList(Seq.empty),
        extraOptions.get("path"),
@@ -303,7 +331,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
      val cmd = CreateTable(
        UnresolvedIdentifier(originalMultipartIdentifier),
        df.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)),
        partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq,
        partitioningOrClusteringTransform,
        tableSpec,
        ignoreIfExists = false)
      Dataset.ofRows(df.sparkSession, cmd)
@@ -439,10 +467,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
  }

  private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
    // Do not allow the user to specify clustering columns in the options. Ignoring this option is
    // consistent with the behavior of DataFrameWriter on non Path-based tables and with the
    // behavior of DataStreamWriter on partitioning columns specified in options.
    val optionsWithoutClusteringKey =
      optionsWithPath.originalMap - DataSourceUtils.CLUSTERING_COLUMNS_KEY

    val optionsWithClusteringColumns = normalizedClusteringCols match {
      case Some(cols) => optionsWithoutClusteringKey + (
        DataSourceUtils.CLUSTERING_COLUMNS_KEY ->
          DataSourceUtils.encodePartitioningColumns(cols))
      case None => optionsWithoutClusteringKey
    }
    val ds = DataSource(
      df.sparkSession,
      className = source,
      options = optionsWithPath.originalMap,
      options = optionsWithClusteringColumns,
      partitionColumns = normalizedParCols.getOrElse(Nil))
    ds.createSink(outputMode)
  }
@@ -514,6 +554,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    cols.map(normalize(_, "Partition"))
  }

  private def normalizedClusteringCols: Option[Seq[String]] = clusteringColumns.map { cols =>
    cols.map(normalize(_, "Clustering"))
  }

  /**
   * The given column name may not be equal to any of the existing column names if we were in
   * case-insensitive context. Normalize the given column name to the real one so that we don't
@@ -532,6 +576,13 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    }
  }

  // Validate that partitionBy isn't used with clusterBy.
  private def validatePartitioningAndClustering(): Unit = {
    if (clusteringColumns.nonEmpty && partitioningColumns.nonEmpty) {
      throw QueryCompilationErrors.clusterByWithPartitionedBy()
    }
  }

  ///////////////////////////////////////////////////////////////////////////////////////
  // Builder pattern config options
  ///////////////////////////////////////////////////////////////////////////////////////
@@ -554,6 +605,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
  private var foreachBatchWriter: (Dataset[T], Long) => Unit = null

  private var partitioningColumns: Option[Seq[String]] = None

  private var clusteringColumns: Option[Seq[String]] = None
}

object DataStreamWriter {
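
Because both `partitionBy` and `clusterBy` call `validatePartitioningAndClustering()` eagerly, combining the two on one streaming writer is rejected. A hypothetical sketch of that behavior in a ScalaTest-style test (the table name is a placeholder, and the error is assumed to surface as an `AnalysisException` raised by `QueryCompilationErrors.clusterByWithPartitionedBy`):

```scala
// Hypothetical sketch: mixing partitionBy and clusterBy on a streaming writer should fail.
val df = spark.readStream.format("rate").load()

intercept[AnalysisException] {
  df.writeStream
    .partitionBy("timestamp")
    .clusterBy("value")    // throws here: partitioning columns were already set
    .toTable("events")     // placeholder table name; never reached
}
```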