From ec1a8bcc9a2c010af0ca59afec69436785a75451 Mon Sep 17 00:00:00 2001
From: Johan Lasperas
Date: Fri, 19 Jan 2024 17:47:23 +0100
Subject: [PATCH] Add widening type promotion from integers to decimals in Parquet vectorized reader

---
 .../parquet/ParquetVectorUpdaterFactory.java |  33 ++++-
 .../parquet/VectorizedColumnReader.java      |   7 +-
 .../parquet/ParquetQuerySuite.scala          |   8 +-
 .../parquet/ParquetTypeWideningSuite.scala   | 123 +++++++++++++++---
 4 files changed, 144 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 3863818b02556..67cbc7a0973e7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -1406,7 +1406,11 @@ private static class IntegerToDecimalUpdater extends DecimalUpdater {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }

     @Override
@@ -1435,14 +1439,18 @@ public void decodeSingleDictionaryId(
     }
   }

-private static class LongToDecimalUpdater extends DecimalUpdater {
+  private static class LongToDecimalUpdater extends DecimalUpdater {
     private final int parquetScale;

-  LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }

     @Override
@@ -1651,6 +1659,20 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp
       int scaleIncrease = requestedType.scale() - parquetType.getScale();
       int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
       return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    } else if (typeAnnotation == null || typeAnnotation instanceof IntLogicalTypeAnnotation) {
+      // Allow reading integers (which may be un-annotated) as decimal as long as the requested
+      // decimal type is large enough to represent all possible values.
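+      // [Editor's illustrative note, not part of the original commit] For example, an
+      // un-annotated INT32 column needs an integer part of at least 10 digits (Int.MaxValue is
+      // 2147483647), so this check accepts DECIMAL(10, 0) or DECIMAL(12, 2) but rejects
+      // DECIMAL(9, 0) and DECIMAL(10, 1); an INT64 column needs at least 20 digits.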
+      PrimitiveType.PrimitiveTypeName typeName =
+        descriptor.getPrimitiveType().getPrimitiveTypeName();
+      int integerPrecision = requestedType.precision() - requestedType.scale();
+      switch (typeName) {
+        case INT32:
+          return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
+        case INT64:
+          return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
+        default:
+          return false;
+      }
     }
     return false;
   }
@@ -1661,6 +1683,9 @@ private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType
     if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
       DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
       return decimalType.getScale() == d.scale();
+    } else if (typeAnnotation == null || typeAnnotation instanceof IntLogicalTypeAnnotation) {
+      // Consider integers (which may be un-annotated) as having scale 0.
+      return d.scale() == 0;
     }
     return false;
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 6479644968ed8..137387a35e955 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -151,10 +151,9 @@ private boolean isLazyDecodingSupported(
     // rebasing.
     switch (typeName) {
       case INT32: {
-        boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
-          (isDate && sparkType == TimestampNTZType) ||
+          sparkType == TimestampNTZType ||
           (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
@@ -162,7 +161,7 @@ private boolean isLazyDecodingSupported(
         break;
       }
       case INT64: {
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index b306a526818e3..b8a6cb5d0712e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS

     withAllParquetReaders {
       // We can read the decimal parquet field with a larger precision, if scale is the same.
-      val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
-      checkAnswer(readParquet(schema, path), df)
+      val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema1, path), df)
+      val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema2, path), df)
     }

     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS

     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
       checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
+      checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
       checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
       checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0"))
       checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
       checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
+      checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
       val e = intercept[SparkException] {
         readParquet("d DECIMAL(3, 2)", path).collect()
       }.getCause
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 7b8357e207742..a305607193ee7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet

 import java.io.File

+import org.apache.parquet.column.{Encoding, ParquetProperties}
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.format.converter.ParquetMetadataConverter
 import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}

 class ParquetTypeWideningSuite
     extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
     if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
       assertAllParquetFilesDictionaryEncoded(dir)
     }
+
+    // Check which encoding was used when writing Parquet V2 files.
+    val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
+      .contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
+    if (isParquetV2) {
+      if (dictionaryEnabled) {
+        assertParquetV2Encoding(dir, Encoding.PLAIN)
+      } else if (DecimalType.is64BitDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
+      } else if (DecimalType.isByteArrayDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
+      }
+    }
     df
   }

@@ -145,6 +160,27 @@ class ParquetTypeWideningSuite
     }
   }

+  /**
+   * Asserts that all parquet files in the given directory have all their columns encoded with the
+   * given encoding.
+   */
+  private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.getEncodings.contains(expected_encoding),
+            s"Expected column '${col.getPath.toDotString}' to use encoding $expected_encoding " +
+              s"but found ${col.getEncodings}.")
+        }
+      }
+    }
+  }
+
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@ class ParquetTypeWideningSuite
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
     )
   }
-  test(s"parquet widening conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = false)
-  }
+    test(s"parquet widening conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = false)
+    }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
+      (Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Byte.MaxValue.toString), ByteType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(LongDecimal.precision + 1, 1))
+    )
+  }
+    test(s"parquet widening conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = false)
+    }

   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
       (Seq("1.23", "10.34"), DoubleType, FloatType),
       (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("1", "10"), LongType, DoubleType),
       (Seq("1", "10"), LongType, DateType),
       (Seq("1", "10"), IntegerType, TimestampType),
       (Seq("1", "10"), IntegerType, TimestampNTZType),
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
     )
   }
-  test(s"unsupported parquet conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = true)
-  }
+    test(s"unsupported parquet conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = true)
+    }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
+      // Parquet stores byte, short, int values as INT32, which then requires using a decimal that
+      // can hold at least 4 byte integers.
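+      // [Editor's illustrative note, not part of the original commit] For example, ByteDecimal is
+      // DECIMAL(3, 0): wide enough for any byte value, but narrower than the DECIMAL(10, 0)
+      // required for INT32-backed columns, so ByteType -> ByteDecimal below is rejected by the
+      // vectorized reader (parquet-mr still reads it without an overflow check).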
+      (Seq("1", "2"), ByteType, DecimalType(1, 0)),
+      (Seq("1", "2"), ByteType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ShortDecimal),
+      (Seq("1", "2"), IntegerType, ShortDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
+      (Seq("1", "2"), LongType, IntDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
+    )
+  }
+    test(s"unsupported parquet conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType,
+        expectError =
+          // parquet-mr allows reading decimals into a smaller precision decimal type without
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
+          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+    }

   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
@@ -201,17 +290,17 @@ class ParquetTypeWideningSuite
       Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
       Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
   }
-  test(
-    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
-    checkAllParquetReaders(
-      values = Seq("1.23", "10.34"),
-      fromType = DecimalType(fromPrecision, 2),
-      toType = DecimalType(toPrecision, 2),
-      expectError = fromPrecision > toPrecision &&
-        // parquet-mr allows reading decimals into a smaller precision decimal type without
-        // checking for overflows. See test below checking for the overflow case in parquet-mr.
-        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
-  }
+    test(
+      s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
+      checkAllParquetReaders(
+        values = Seq("1.23", "10.34"),
+        fromType = DecimalType(fromPrecision, 2),
+        toType = DecimalType(toPrecision, 2),
+        expectError = fromPrecision > toPrecision &&
+          // parquet-mr allows reading decimals into a smaller precision decimal type without
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
+          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+    }

   for {
     ((fromPrecision, fromScale), (toPrecision, toScale)) <-