[SPARK-48910][SQL] Use HashSet/HashMap to avoid linear searches in PreprocessTableCreation

### What changes were proposed in this pull request?

Use `HashSet`/`HashMap` instead of doing linear searches over the `Seq`. With thousands of partitions this significantly improves performance.
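
For illustration, a minimal standalone sketch of the before/after lookup pattern. The `Column` type, sizes, and names here are invented for the example; the actual rule operates on `Attribute`/`StructField` values inside `PreprocessTableCreation`:

```scala
import scala.collection.mutable.{HashMap, HashSet}

object LookupSketch extends App {
  // Hypothetical stand-in for the rule's Attribute/StructField objects.
  case class Column(name: String)

  val output: Seq[Column] = (1 to 10000).map(i => Column(s"c$i"))
  val partitionColumnNames: Seq[String] = (9001 to 10000).map(i => s"c$i")

  // Before: one linear scan of `output` per partition column -- O(n*m) overall.
  val slowPartitionAttrs = partitionColumnNames.map { partCol =>
    output.find(_.name == partCol).get
  }

  // After: build an index once in O(n); each subsequent lookup is O(1).
  val outputByName = HashMap(output.map(o => o.name -> o): _*)
  val partitionAttrs = partitionColumnNames.map(outputByName(_))

  // Same idea for membership tests: Seq.contains is linear, HashSet is not.
  val partitionAttrsSet = HashSet(partitionAttrs: _*)
  val newOutput = output.filterNot(partitionAttrsSet.contains) ++ partitionAttrs

  assert(slowPartitionAttrs == partitionAttrs)
  println(s"reordered ${newOutput.size} columns")
}
```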

### Why are the changes needed?

To avoid the O(n*m) passes in `PreprocessTableCreation`.
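
Concretely: resolving m partition columns against n output columns with `output.find` costs up to n comparisons per column, i.e. O(n*m) in total, while building a `HashMap` once costs O(n) and each lookup O(1), i.e. O(n+m) overall. With, say, n = 10,000 columns and m = 1,000 partition columns, that is up to ten million comparisons versus roughly eleven thousand hash operations.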

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing UTs

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47484 from vladimirg-db/vladimirg-db/get-rid-of-linear-searches-preprocess-table-creation.

Authored-by: Vladimir Golubev <vladimir.golubev@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
vladimirg-db authored and cloud-fan committed Jul 29, 2024
1 parent efc6a75 commit db03653
Showing 1 changed file with 15 additions and 7 deletions.
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.Locale
 
+import scala.collection.mutable.{HashMap, HashSet}
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
@@ -248,10 +249,14 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[LogicalPlan]
       DDLUtils.checkTableColumns(tableDesc.copy(schema = analyzedQuery.schema))
 
       val output = analyzedQuery.output
+
+      val outputByName = HashMap(output.map(o => o.name -> o): _*)
       val partitionAttrs = normalizedTable.partitionColumnNames.map { partCol =>
-        output.find(_.name == partCol).get
+        outputByName(partCol)
       }
-      val newOutput = output.filterNot(partitionAttrs.contains) ++ partitionAttrs
+      val partitionAttrsSet = HashSet(partitionAttrs: _*)
+      val newOutput = output.filterNot(partitionAttrsSet.contains) ++ partitionAttrs
+
       val reorderedQuery = if (newOutput == output) {
         analyzedQuery
       } else {
@@ -263,12 +268,14 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[LogicalPlan]
       DDLUtils.checkTableColumns(tableDesc)
       val normalizedTable = normalizeCatalogTable(tableDesc.schema, tableDesc)
 
+      val normalizedSchemaByName = HashMap(normalizedTable.schema.map(s => s.name -> s): _*)
       val partitionSchema = normalizedTable.partitionColumnNames.map { partCol =>
-        normalizedTable.schema.find(_.name == partCol).get
+        normalizedSchemaByName(partCol)
       }
 
-      val reorderedSchema =
-        StructType(normalizedTable.schema.filterNot(partitionSchema.contains) ++ partitionSchema)
+      val partitionSchemaSet = HashSet(partitionSchema: _*)
+      val reorderedSchema = StructType(
+        normalizedTable.schema.filterNot(partitionSchemaSet.contains) ++ partitionSchema
+      )
 
       c.copy(tableDesc = normalizedTable.copy(schema = reorderedSchema))
     }
@@ -360,8 +367,9 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[LogicalPlan]
         messageParameters = Map.empty)
     }
 
+    val normalizedPartitionColsSet = HashSet(normalizedPartitionCols: _*)
     schema
-      .filter(f => normalizedPartitionCols.contains(f.name))
+      .filter(f => normalizedPartitionColsSet.contains(f.name))
       .foreach { field =>
         if (!PartitioningUtils.canPartitionOn(field.dataType)) {
           throw QueryCompilationErrors.invalidPartitionColumnDataTypeError(field)
```
