diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 86c10b834f31d..e44d3eacc66df 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -87,7 +87,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
    */
   def createLocalRelationProto(
       attrs: Seq[AttributeReference],
-      data: Seq[InternalRow]): proto.Relation = {
+      data: Seq[InternalRow],
+      timeZoneId: String = "UTC"): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
 
     val bytes = ArrowConverters
@@ -96,7 +97,7 @@ trait SparkConnectPlanTest extends SharedSparkSession {
         DataTypeUtils.fromAttributes(attrs.map(_.toAttribute)),
         Long.MaxValue,
         Long.MaxValue,
-        null,
+        timeZoneId,
         true)
       .next()
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 190f8cde16f50..03bf5a4c10dbc 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog,
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
@@ -64,6 +64,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
     Seq.empty)
 
+  lazy val connectTestRelation3 =
+    createLocalRelationProto(
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("date", TimestampType)()),
+      Seq.empty)
+
   lazy val connectTestRelationMap =
     createLocalRelationProto(
       Seq(AttributeReference("id", MapType(StringType, StringType))()),
@@ -79,6 +84,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     new java.util.ArrayList[Row](),
     StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
 
+  lazy val sparkTestRelation3: DataFrame =
+    spark.createDataFrame(
+      new java.util.ArrayList[Row](),
+      StructType(Seq(StructField("id", IntegerType), StructField("date", TimestampType))))
+
   lazy val sparkTestRelationMap: DataFrame =
     spark.createDataFrame(
       new java.util.ArrayList[Row](),
@@ -93,6 +103,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan, sparkPlan)
   }
 
+  test("Basic select timestamp") {
+    val connectPlan = connectTestRelation3.select("date".protoAttr)
+    val sparkPlan = sparkTestRelation3.select("date")
+    comparePlans(connectPlan, sparkPlan)
+  }
+
   test("Test select expression in strings") {
     val connectPlan = connectTestRelation.selectExpr("abs(id)", "name")
     val sparkPlan = sparkTestRelation.selectExpr("abs(id)", "name")