Skip to content

Commit

Permalink
[SPARK-49112][CONNECT][TEST] Make createLocalRelationProto support …
Browse files Browse the repository at this point in the history
…`TimestampType`

### What changes were proposed in this pull request?
Make `createLocalRelationProto` support relations with `TimestampType`.

### Why are the changes needed?
The existing helper function `createLocalRelationProto` cannot create a relation containing a `TimestampType` column; it fails with the following internal error:

```
  org.apache.spark.SparkException: [INTERNAL_ERROR] Missing timezoneId where it is mandatory. SQLSTATE: XX000
  at org.apache.spark.SparkException$.internalError(SparkException.scala:99)
  at org.apache.spark.SparkException$.internalError(SparkException.scala:103)
  at org.apache.spark.sql.util.ArrowUtils$.toArrowType(ArrowUtils.scala:57)
  at org.apache.spark.sql.util.ArrowUtils$.toArrowField(ArrowUtils.scala:139)
  at org.apache.spark.sql.util.ArrowUtils$.$anonfun$toArrowSchema$1(ArrowUtils.scala:181)
  at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
  at scala.collection.Iterator.foreach(Iterator.scala:943)
  at scala.collection.Iterator.foreach$(Iterator.scala:943)
  at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
  at scala.collection.IterableLike.foreach(IterableLike.scala:74)
```

### Does this PR introduce _any_ user-facing change?
No, test-only

### How was this patch tested?
Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#47608 from zhengruifeng/create_timestamp_localrel.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Aug 6, 2024
1 parent de8ee94 commit 6d32472
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ trait SparkConnectPlanTest extends SharedSparkSession {
*/
def createLocalRelationProto(
attrs: Seq[AttributeReference],
data: Seq[InternalRow]): proto.Relation = {
data: Seq[InternalRow],
timeZoneId: String = "UTC"): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()

val bytes = ArrowConverters
Expand All @@ -96,7 +97,7 @@ trait SparkConnectPlanTest extends SharedSparkSession {
DataTypeUtils.fromAttributes(attrs.map(_.toAttribute)),
Long.MaxValue,
Long.MaxValue,
null,
timeZoneId,
true)
.next()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog,
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

Expand All @@ -64,6 +64,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
Seq.empty)

lazy val connectTestRelation3 =
createLocalRelationProto(
Seq(AttributeReference("id", IntegerType)(), AttributeReference("date", TimestampType)()),
Seq.empty)

lazy val connectTestRelationMap =
createLocalRelationProto(
Seq(AttributeReference("id", MapType(StringType, StringType))()),
Expand All @@ -79,6 +84,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))

lazy val sparkTestRelation3: DataFrame =
spark.createDataFrame(
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", IntegerType), StructField("date", TimestampType))))

lazy val sparkTestRelationMap: DataFrame =
spark.createDataFrame(
new java.util.ArrayList[Row](),
Expand All @@ -93,6 +103,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan, sparkPlan)
}

test("Basic select timestamp") {
val connectPlan = connectTestRelation3.select("date".protoAttr)
val sparkPlan = sparkTestRelation3.select("date")
comparePlans(connectPlan, sparkPlan)
}

test("Test select expression in strings") {
val connectPlan = connectTestRelation.selectExpr("abs(id)", "name")
val sparkPlan = sparkTestRelation.selectExpr("abs(id)", "name")
Expand Down

0 comments on commit 6d32472

Please sign in to comment.