diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
index f5e5f729459f7..375d69034fd31 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
@@ -26,10 +26,7 @@ import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.*;
 
 import com.fasterxml.jackson.core.JsonFactory;
 import com.fasterxml.jackson.core.JsonParser;
@@ -43,24 +40,29 @@
  * Build variant value and metadata by parsing JSON values.
  */
 public class VariantBuilder {
+  public VariantBuilder(boolean allowDuplicateKeys) {
+    this.allowDuplicateKeys = allowDuplicateKeys;
+  }
+
   /**
    * Parse a JSON string as a Variant value.
    * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed
    * the SIZE_LIMIT (for example, this could be a maximum of 16 MiB).
    * @throws IOException if any JSON parsing error happens.
    */
-  public static Variant parseJson(String json) throws IOException {
+  public static Variant parseJson(String json, boolean allowDuplicateKeys) throws IOException {
     try (JsonParser parser = new JsonFactory().createParser(json)) {
       parser.nextToken();
-      return parseJson(parser);
+      return parseJson(parser, allowDuplicateKeys);
     }
   }
 
   /**
-   * Similar {@link #parseJson(String)}, but takes a JSON parser instead of string input.
+   * Similar to {@link #parseJson(String, boolean)}, but takes a JSON parser instead of string input.
    */
-  public static Variant parseJson(JsonParser parser) throws IOException {
-    VariantBuilder builder = new VariantBuilder();
+  public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys)
+      throws IOException {
+    VariantBuilder builder = new VariantBuilder(allowDuplicateKeys);
     builder.buildJson(parser);
     return builder.result();
   }
@@ -274,23 +276,63 @@ public int getWritePos() {
   // record the offset of the field. The offset is computed as `getWritePos() - start`.
   // 3. The caller calls `finishWritingObject` to finish writing a variant object.
   //
-  // This function is responsible to sort the fields by key and check for any duplicate field keys.
+  // This function is responsible for sorting the fields by key. If there are duplicate field keys:
+  // - when `allowDuplicateKeys` is true, the field with the greatest offset value (the last
+  //   appended one) is kept.
+  // - otherwise, throw an exception.
   public void finishWritingObject(int start, ArrayList<FieldEntry> fields) {
-    int dataSize = writePos - start;
     int size = fields.size();
     Collections.sort(fields);
     int maxId = size == 0 ? 0 : fields.get(0).id;
-    // Check for duplicate field keys. Only need to check adjacent key because they are sorted.
-    for (int i = 1; i < size; ++i) {
-      maxId = Math.max(maxId, fields.get(i).id);
-      String key = fields.get(i).key;
-      if (key.equals(fields.get(i - 1).key)) {
-        @SuppressWarnings("unchecked")
-        Map<String, String> parameters = Map$.MODULE$.empty().updated("key", key);
-        throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
-          null, new QueryContext[]{}, "");
+    if (allowDuplicateKeys) {
+      int distinctPos = 0;
+      // Maintain a list of distinct keys in-place.
+      for (int i = 1; i < size; ++i) {
+        maxId = Math.max(maxId, fields.get(i).id);
+        if (fields.get(i).id == fields.get(i - 1).id) {
+          // Found a duplicate key. Keep the field with the greater offset.
+          if (fields.get(distinctPos).offset < fields.get(i).offset) {
+            fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset));
+          }
+        } else {
+          // Found a distinct key. Add the field to the list.
+          ++distinctPos;
+          fields.set(distinctPos, fields.get(i));
+        }
+      }
+      if (distinctPos + 1 < fields.size()) {
+        size = distinctPos + 1;
+        // Resize `fields` to `size`.
+        fields.subList(size, fields.size()).clear();
+        // Sort the fields by offsets so that we can move the value data of each field to the new
+        // offset without overwriting the fields after it.
+        fields.sort(Comparator.comparingInt(f -> f.offset));
+        int currentOffset = 0;
+        for (int i = 0; i < size; ++i) {
+          int oldOffset = fields.get(i).offset;
+          int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset);
+          System.arraycopy(writeBuffer, start + oldOffset,
+            writeBuffer, start + currentOffset, fieldSize);
+          fields.set(i, fields.get(i).withNewOffset(currentOffset));
+          currentOffset += fieldSize;
+        }
+        writePos = start + currentOffset;
+        // Change back to the sort order by field keys to meet the variant spec.
+        Collections.sort(fields);
+      }
+    } else {
+      for (int i = 1; i < size; ++i) {
+        maxId = Math.max(maxId, fields.get(i).id);
+        String key = fields.get(i).key;
+        if (key.equals(fields.get(i - 1).key)) {
+          @SuppressWarnings("unchecked")
+          Map<String, String> parameters = Map$.MODULE$.empty().updated("key", key);
+          throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
+            null, new QueryContext[]{}, "");
+        }
       }
     }
+    int dataSize = writePos - start;
     boolean largeSize = size > U8_MAX;
     int sizeBytes = largeSize ? U32_SIZE : 1;
     int idSize = getIntegerSize(maxId);
@@ -415,6 +457,10 @@ public FieldEntry(String key, int id, int offset) {
       this.offset = offset;
     }
 
+    FieldEntry withNewOffset(int newOffset) {
+      return new FieldEntry(key, id, newOffset);
+    }
+
     @Override
     public int compareTo(FieldEntry other) {
       return key.compareTo(other.key);
@@ -518,4 +564,5 @@ private boolean tryParseDecimal(String input) {
   private final HashMap<String, Integer> dictionary = new HashMap<>();
   // Store all keys in `dictionary` in the order of id.
   private final ArrayList<byte[]> dictionaryKeys = new ArrayList<>();
+  private final boolean allowDuplicateKeys;
 }
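The dedup pass above keeps the last-appended value for each repeated key and then physically compacts the surviving field data. A minimal sanity check of that property against the new `parseJson` signature — illustrative only, not part of this patch:

```scala
import java.util.Arrays

import org.apache.spark.SparkRuntimeException
import org.apache.spark.types.variant.VariantBuilder

// The last occurrence wins and the loser's bytes are compacted away, so the
// result should be byte-identical to parsing the deduplicated input. (Both
// inputs intern the single key "a", so the metadata dictionaries match too.)
val deduped = VariantBuilder.parseJson("""{"a": 1, "a": 3}""", true)
val direct = VariantBuilder.parseJson("""{"a": 3}""", false)
assert(Arrays.equals(deduped.getValue, direct.getValue))
assert(Arrays.equals(deduped.getMetadata, direct.getMetadata))

// With the flag off, the same input is rejected with VARIANT_DUPLICATE_KEY.
try {
  VariantBuilder.parseJson("""{"a": 1, "a": 3}""", false)
  assert(false, "expected VARIANT_DUPLICATE_KEY")
} catch {
  case e: SparkRuntimeException => assert(e.getErrorClass == "VARIANT_DUPLICATE_KEY")
}
```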
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index 7682e4a048ecb..bbf554d384b12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -31,7 +31,10 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
  */
 object VariantExpressionEvalUtils {
 
-  def parseJson(input: UTF8String, failOnError: Boolean = true): VariantVal = {
+  def parseJson(
+      input: UTF8String,
+      allowDuplicateKeys: Boolean = false,
+      failOnError: Boolean = true): VariantVal = {
     def parseJsonFailure(exception: Throwable): VariantVal = {
       if (failOnError) {
         throw exception
@@ -40,7 +43,7 @@ object VariantExpressionEvalUtils {
       }
     }
     try {
-      val v = VariantBuilder.parseJson(input.toString)
+      val v = VariantBuilder.parseJson(input.toString, allowDuplicateKeys)
       new VariantVal(v.getValue, v.getMetadata)
     } catch {
       case _: VariantSizeLimitException =>
@@ -69,7 +72,9 @@ object VariantExpressionEvalUtils {
 
   /** Cast a Spark value from `dataType` into the variant type. */
   def castToVariant(input: Any, dataType: DataType): VariantVal = {
-    val builder = new VariantBuilder
+    // Enforce the strict check because it is illegal for an input struct/map/variant to contain
+    // duplicate keys.
+    val builder = new VariantBuilder(false)
     buildVariant(builder, input, dataType)
     val v = builder.result()
     new VariantVal(v.getValue, v.getMetadata)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index de7544381381a..bd956fa5c00e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -59,8 +59,11 @@ case class ParseJson(child: Expression, failOnError: Boolean = true)
       VariantExpressionEvalUtils.getClass,
       VariantType,
       "parseJson",
-      Seq(child, Literal(failOnError, BooleanType)),
-      inputTypes :+ BooleanType,
+      Seq(
+        child,
+        Literal(SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS), BooleanType),
+        Literal(failOnError, BooleanType)),
+      inputTypes :+ BooleanType :+ BooleanType,
       returnNullable = !failOnError)
 
   override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
@@ -324,7 +327,7 @@ case object VariantGet {
     if (dataType == VariantType) {
       // Build a new variant, in order to strip off any unnecessary metadata.
-      val builder = new VariantBuilder
+      val builder = new VariantBuilder(false)
       builder.appendVariant(v)
       val result = builder.result()
       return new VariantVal(result.getValue, result.getMetadata)
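Note that `ParseJson` reads `VARIANT_ALLOW_DUPLICATE_KEYS` when the expression is constructed and bakes it into the plan as a boolean literal, which is why the explain outputs later in this diff change from `parseJson(g#0, true)` to `parseJson(g#0, false, true)`. A hedged spark-shell sketch of the user-visible behavior (assumes a `SparkSession` named `spark`; the output comments are expectations, not captured logs):

```scala
// Analysis-time capture: the conf value in effect when the query is analyzed
// is what ends up in the plan.
spark.conf.set("spark.sql.variant.allowDuplicateKeys", "true")
spark.sql("""SELECT to_json(parse_json('{"a": 1, "a": 2}'))""").show()
// expected: {"a":2} -- the last occurrence of "a" wins

spark.conf.set("spark.sql.variant.allowDuplicateKeys", "false")
spark.sql("""SELECT try_parse_json('{"a": 1, "a": 2}')""").show()
// expected: NULL -- parse_json would instead fail with VARIANT_DUPLICATE_KEY
```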
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index d10e4304c86fd..26de4cc7ad1c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -117,6 +117,8 @@ class JacksonParser(
     }
   }
 
+  private val variantAllowDuplicateKeys = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)
+
   protected final def parseVariant(parser: JsonParser): VariantVal = {
     // Skips `FIELD_NAME` at the beginning. This check is adapted from `parseJsonToken`, but we
     // cannot directly use the function here because it also handles the `VALUE_NULL` token and
@@ -125,7 +127,7 @@ class JacksonParser(
       parser.nextToken()
     }
     try {
-      val v = VariantBuilder.parseJson(parser)
+      val v = VariantBuilder.parseJson(parser, variantAllowDuplicateKeys)
       new VariantVal(v.getValue, v.getMetadata)
     } catch {
       case _: VariantSizeLimitException =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 10cd057d9562c..72915f0e5c256 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4412,6 +4412,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val VARIANT_ALLOW_DUPLICATE_KEYS =
+    buildConf("spark.sql.variant.allowDuplicateKeys")
+      .internal()
+      .doc("When set to false, parsing variant from JSON will throw an error if there are " +
+        "duplicate keys in the input JSON object. When set to true, the parser will keep the " +
+        "last occurrence of all fields with the same key.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
     buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
       .internal()
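The JSON scan path consults the same flag: `JacksonParser` captures it once in `variantAllowDuplicateKeys` and forwards it to `VariantBuilder.parseJson`. A sketch of how that surfaces through `from_json` — assuming `from_json` accepts `VariantType` here, as the round-trip test in `VariantEndToEndSuite` exercises, and a `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.{col, from_json, to_json}
import org.apache.spark.sql.types.VariantType
import spark.implicits._

spark.conf.set("spark.sql.variant.allowDuplicateKeys", "true")
val parsed = Seq("""{"k": 1, "k": 2}""").toDF("j")
  .select(from_json(col("j"), VariantType).as("v"))
parsed.select(to_json(col("v"))).show(truncate = false)
// expected: {"k":2} -- JacksonParser.parseVariant forwards the captured flag
```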
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
index 8fc72caa47860..2b8c64a1af679 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
@@ -24,12 +24,14 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
 
   test("parseJson type coercion") {
-    def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
+    def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte],
+        allowDuplicateKeys: Boolean = false): Unit = {
       // parse_json
-      val actual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+      val actual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
+        allowDuplicateKeys)
       // try_parse_json
       val tryActual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
-        failOnError = false)
+        allowDuplicateKeys, failOnError = false)
       val expected = new VariantVal(expectedValue, expectedMetadata)
       assert(actual === expected && tryActual === expected)
     }
@@ -89,6 +91,13 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
       /* offset list */ 0, 2, 4, 6,
       /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
       Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
+    check("""{"a": 1, "b": 2, "c": "3", "a": 4}""", Array(objectHeader(false, 1, 1),
+      /* size */ 3,
+      /* id list */ 0, 1, 2,
+      /* offset list */ 4, 0, 2, 6,
+      /* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', primitiveHeader(INT1), 4),
+      Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'),
+      allowDuplicateKeys = true)
     check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
       /* size */ 3,
       /* id list */ 2, 1, 0,
@@ -109,10 +118,11 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
   test("parseJson negative") {
     def checkException(json: String, errorClass: String, parameters: Map[String, String]): Unit = {
       val try_parse_json_output = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
-        failOnError = false)
+        allowDuplicateKeys = false, failOnError = false)
       checkError(
         exception = intercept[SparkThrowable] {
-          VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+          VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
+            allowDuplicateKeys = false)
         },
         errorClass = errorClass,
         parameters = parameters
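The expected bytes in the new `check` call are easiest to read with the post-compaction layout in mind: the id list stays sorted by key, while the offset list records where each surviving value physically landed. A worked annotation of the test's arrays (values copied from the test above):

```scala
// {"a": 1, "b": 2, "c": "3", "a": 4} with allowDuplicateKeys = true:
//   id list [0, 1, 2] = keys (a, b, c) in sorted order, dictionary (a, b, c).
// Append order was a:1, b:2, c:"3", a:4; the first "a" is dropped and the
// rest is compacted, so the field data region becomes:
//   offset 0: primitiveHeader(INT1), 2   -- value of "b"
//   offset 2: shortStrHeader(1), '3'     -- value of "c"
//   offset 4: primitiveHeader(INT1), 4   -- surviving (last) value of "a"
//   offset 6: end of field data
// Hence the offset list [4, 0, 2, 6]: entry i points at the data of the field
// with the i-th smallest key, plus a trailing end-of-data offset.
```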
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain
index e750021ce22bb..988447c2d6418 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_is_variant_null.explain
@@ -1,2 +1,2 @@
-Project [static_invoke(VariantExpressionEvalUtils.isVariantNull(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)))) AS is_variant_null(parse_json(g))#0]
+Project [static_invoke(VariantExpressionEvalUtils.isVariantNull(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)))) AS is_variant_null(parse_json(g))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain
index cbcf803b39010..a40f89c03c888 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_json.explain
@@ -1,2 +1,2 @@
-Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)) AS parse_json(g)#0]
+Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)) AS parse_json(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain
index 04b33fdd70678..c82a10655c332 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant.explain
@@ -1,2 +1,2 @@
-Project [static_invoke(SchemaOfVariant.schemaOfVariant(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)))) AS schema_of_variant(parse_json(g))#0]
+Project [static_invoke(SchemaOfVariant.schemaOfVariant(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)))) AS schema_of_variant(parse_json(g))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain
index 18e8801bb2986..3d894628ab7e0 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_variant_agg.explain
@@ -1,2 +1,2 @@
-Aggregate [schema_of_variant_agg(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), 0, 0) AS schema_of_variant_agg(parse_json(g))#0]
+Aggregate [schema_of_variant_agg(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), 0, 0) AS schema_of_variant_agg(parse_json(g))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain
index 826ec4fc81d83..7a0c0078128f5 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_json.explain
@@ -1,2 +1,2 @@
-Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false)) AS try_parse_json(g)#0]
+Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, false)) AS try_parse_json(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain
index 933fbff8e1f3d..65f527da78c4a 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_variant_get.explain
@@ -1,2 +1,2 @@
-Project [try_variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0]
+Project [try_variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain
index 2e0baf058f72a..33c6b3a52f529 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_get.explain
@@ -1,2 +1,2 @@
-Project [variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0]
+Project [variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index b5e22e71a8fa0..4a20ec4af7e65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarArray
@@ -57,6 +58,12 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     )
     // scalastyle:on nonascii
     check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+    withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") {
+      check(
+        """{"c": [], "b": 0, "a": null, "a": {"x": 0, "x": 1}, "b": 1, "b": 2, "c": [3]}""",
+        """{"a":{"x":1},"b":2,"c":[3]}"""
+      )
+    }
   }
 
   test("from_json/to_json round-trip") {
@@ -146,7 +153,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     val variantDF = df.select(parse_json(col("v")))
     val plan = variantDF.queryExecution.executedPlan
     assert(plan.isInstanceOf[WholeStageCodegenExec])
-    val v = VariantBuilder.parseJson("""{"a":1}""")
+    val v = VariantBuilder.parseJson("""{"a":1}""", false)
     val expected = new VariantVal(v.getValue, v.getMetadata)
     checkAnswer(variantDF, Seq(Row(expected)))
   }
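A complementary negative case could sit next to the new end-to-end test: with the conf at its default `false`, the same duplicate-key input should fail `parse_json` and degrade to `NULL` under `try_parse_json`. A sketch, assuming it runs inside a `QueryTest` suite like the one above so `spark`, `intercept`, and `checkAnswer` are in scope:

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.Row

val dup = """{"a": 0, "a": 1}"""
// Strict mode rejects the duplicate key when the query is executed.
val e = intercept[SparkRuntimeException] {
  spark.sql(s"SELECT parse_json('$dup')").collect()
}
assert(e.getErrorClass == "VARIANT_DUPLICATE_KEY")
// try_parse_json returns NULL instead of failing.
checkAnswer(spark.sql(s"SELECT try_parse_json('$dup')"), Row(null))
```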