diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 10e2718da8333..94e893d468b39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -1009,10 +1009,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       append(str)
       append("\n")
 
-      if (innerChildren.nonEmpty) {
+      val innerChildrenLocal = innerChildren
+      if (innerChildrenLocal.nonEmpty) {
         lastChildren.add(children.isEmpty)
         lastChildren.add(false)
-        innerChildren.init.foreach(_.generateTreeString(
+        innerChildrenLocal.init.foreach(_.generateTreeString(
           depth + 2, lastChildren, append, verbose, addSuffix = addSuffix,
           maxFields = maxFields, printNodeId = printNodeId, indent = indent))
         lastChildren.remove(lastChildren.size() - 1)
@@ -1020,7 +1021,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
 
         lastChildren.add(children.isEmpty)
         lastChildren.add(true)
-        innerChildren.last.generateTreeString(
+        innerChildrenLocal.last.generateTreeString(
           depth + 2, lastChildren, append, verbose, addSuffix = addSuffix,
           maxFields = maxFields, printNodeId = printNodeId, indent = indent)
         lastChildren.remove(lastChildren.size() - 1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
index 3da3e646f36b0..11f6ae0e47ee1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
@@ -75,8 +75,12 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * Given a input physical plan, performs the following tasks.
    *   1. Generates the explain output for the input plan excluding the subquery plans.
    *   2. Generates the explain output for each subquery referenced in the plan.
+   *
+   * Note that this is ideally a no-op, since different explain actions operate on different
+   * plan instances. Cached plans are the exception: `InMemoryRelation#innerChildren` shares a
+   * plan instance across queries, so this method is locked to avoid a tag race condition.
    */
-  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = {
+  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = synchronized {
     try {
       // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
       // intentional overwriting of IDs generated in previous AQE iteration
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index b194a55ae3596..f76b702ca6714 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -403,17 +403,7 @@ case class InMemoryRelation(
 
   @volatile var statsOfPlanToCache: Statistics = null
 
-
-  override lazy val innerChildren: Seq[SparkPlan] = {
-    // The cachedPlan needs to be cloned here because it does not get cloned when SparkPlan.clone is
-    // called. This is a problem because when the explain output is generated for
-    // a plan it traverses the innerChildren and modifies their TreeNode.tags. If the plan is not
-    // cloned, there is a thread safety issue in the case that two plans with a shared cache
-    // operator have explain called at the same time. The cachedPlan cannot be cloned because
-    // it contains stateful information so we only clone it for the purpose of generating the
-    // explain output.
-    Seq(cachedPlan.clone())
-  }
+  override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
 
   override def doCanonicalize(): logical.LogicalPlan =
     copy(output = output.map(QueryPlan.normalizeExpressions(_, output)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala
index a5c5ec40af6fe..2c73622739a51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala
@@ -18,22 +18,14 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.expr
 import org.apache.spark.sql.test.SharedSparkSessionBase
 import org.apache.spark.storage.StorageLevel
 
-class InMemoryRelationSuite extends SparkFunSuite with SharedSparkSessionBase {
-  test("SPARK-43157: Clone innerChildren cached plan") {
-    val d = spark.range(1)
-    val relation = InMemoryRelation(StorageLevel.MEMORY_ONLY, d.queryExecution, None)
-    val cloned = relation.clone().asInstanceOf[InMemoryRelation]
-
-    val relationCachedPlan = relation.innerChildren.head
-    val clonedCachedPlan = cloned.innerChildren.head
-
-    // verify the plans are not the same object but are logically equivalent
-    assert(!relationCachedPlan.eq(clonedCachedPlan))
-    assert(relationCachedPlan === clonedCachedPlan)
-  }
+class InMemoryRelationSuite extends SparkFunSuite
+  with SharedSparkSessionBase with AdaptiveSparkPlanHelper {
 
   test("SPARK-46779: InMemoryRelations with the same cached plan are semantically equivalent") {
     val d = spark.range(1)
@@ -41,4 +33,27 @@ class InMemoryRelationSuite extends SparkFunSuite with SharedSparkSessionBase {
     val r2 = r1.withOutput(r1.output.map(_.newInstance()))
     assert(r1.sameResult(r2))
   }
+
+  test("SPARK-47177: Cached SQL plan do not display final AQE plan in explain string") {
+    def findIMRInnerChild(p: SparkPlan): SparkPlan = {
+      val tableCache = find(p) {
+        case _: InMemoryTableScanExec => true
+        case _ => false
+      }
+      assert(tableCache.isDefined)
+      tableCache.get.asInstanceOf[InMemoryTableScanExec].relation.innerChildren.head
+    }
+
+    val d1 = spark.range(1).withColumn("key", expr("id % 100"))
+      .groupBy("key").agg(Map("key" -> "count"))
+    val cached_d2 = d1.cache()
+    val df = cached_d2.withColumn("key2", expr("key % 10"))
+      .groupBy("key2").agg(Map("key2" -> "count"))
+
+    assert(findIMRInnerChild(df.queryExecution.executedPlan).treeString
+      .contains("AdaptiveSparkPlan isFinalPlan=false"))
+    df.collect()
+    assert(findIMRInnerChild(df.queryExecution.executedPlan).treeString
+      .contains("AdaptiveSparkPlan isFinalPlan=true"))
+  }
 }
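For context, a minimal, self-contained sketch (not part of the patch) of the scenario the new lock addresses: two queries embed the same cached relation, so before this change concurrent explain() calls could race on the shared cached plan's TreeNode tags while processPlan traversed innerChildren. The object and query names below are illustrative assumptions, not code from this change.

// Illustration only: two queries sharing one cached relation are explained
// concurrently. With processPlan synchronized, the explain-ID tag bookkeeping
// on the shared cached plan can no longer interleave between threads.
import org.apache.spark.sql.SparkSession

object ExplainRaceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("explain-race-sketch")
      .getOrCreate()

    // One cached relation shared by two derived queries.
    val cached = spark.range(100).selectExpr("id % 10 AS key").cache()
    val q1 = cached.groupBy("key").count()
    val q2 = cached.selectExpr("key * 2 AS k2")

    // Concurrent explain() on plans that embed the same InMemoryRelation.
    val t1 = new Thread(() => q1.explain("formatted"))
    val t2 = new Thread(() => q2.explain("formatted"))
    t1.start(); t2.start()
    t1.join(); t2.join()

    spark.stop()
  }
}

Taken together, the synchronized processPlan and the non-cloning innerChildren let explain output reflect the live cached plan (including the final AQE plan after execution, as the SPARK-47177 test verifies) without reintroducing the tag race that the removed SPARK-43157 clone previously worked around.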