[SPARK-49451] Allow duplicate keys in parse_json
### What changes were proposed in this pull request?

Before the change, `parse_json` throws an error if there are duplicate keys in the input JSON object. After the change, `parse_json` keeps the last field with the same key. This doesn't affect the other variant-building expressions (creating a variant from a struct/map/variant), because it is illegal for their inputs to contain duplicate keys.

The change is guarded by the `spark.sql.variant.allowDuplicateKeys` flag and is disabled by default.
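
For illustration, a minimal sketch (not part of this commit) of how the new flag is expected to behave from the DataFrame API, assuming a running `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.{col, parse_json, to_json}
import spark.implicits._

// Enable the new flag (internal, off by default).
spark.conf.set("spark.sql.variant.allowDuplicateKeys", "true")

// The last occurrence of the duplicate key "a" wins, so this prints {"a":2}.
Seq("""{"a": 1, "a": 2}""").toDF("v")
  .select(to_json(parse_json(col("v"))))
  .show(truncate = false)

// With the flag at its default value (false), the same query fails with
// error class VARIANT_DUPLICATE_KEY.
```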

### Why are the changes needed?

To make data migration simpler. Users won't need to change their data if it contains duplicate keys. The behavior is inspired by https://docs.aws.amazon.com/redshift/latest/dg/super-configurations.html#parsing-options-super (reject duplicate keys or keep the last occurrence).

### Does this PR introduce _any_ user-facing change?

Yes, as described in the first section.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47920 from chenhao-db/allow_duplicate_keys.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
chenhao-db authored and cloud-fan committed Sep 2, 2024
1 parent ec7570e commit 8879df5
Showing 14 changed files with 124 additions and 40 deletions.
@@ -26,10 +26,7 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
@@ -43,24 +40,29 @@
* Build variant value and metadata by parsing JSON values.
*/
public class VariantBuilder {
public VariantBuilder(boolean allowDuplicateKeys) {
this.allowDuplicateKeys = allowDuplicateKeys;
}

/**
* Parse a JSON string as a Variant value.
* @throws VariantSizeLimitException if the resulting variant value or metadata would exceed
* the SIZE_LIMIT (for example, this could be a maximum of 16 MiB).
* @throws IOException if any JSON parsing error happens.
*/
public static Variant parseJson(String json) throws IOException {
public static Variant parseJson(String json, boolean allowDuplicateKeys) throws IOException {
try (JsonParser parser = new JsonFactory().createParser(json)) {
parser.nextToken();
return parseJson(parser);
return parseJson(parser, allowDuplicateKeys);
}
}

/**
* Similar {@link #parseJson(String)}, but takes a JSON parser instead of string input.
* Similar to {@link #parseJson(String, boolean)}, but takes a JSON parser instead of string input.
*/
public static Variant parseJson(JsonParser parser) throws IOException {
VariantBuilder builder = new VariantBuilder();
public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys)
throws IOException {
VariantBuilder builder = new VariantBuilder(allowDuplicateKeys);
builder.buildJson(parser);
return builder.result();
}
@@ -274,23 +276,63 @@ public int getWritePos() {
// record the offset of the field. The offset is computed as `getWritePos() - start`.
// 3. The caller calls `finishWritingObject` to finish writing a variant object.
//
// This function is responsible to sort the fields by key and check for any duplicate field keys.
// This function is responsible for sorting the fields by key. If there are duplicate field keys:
// - when `allowDuplicateKeys` is true, the field with the greatest offset value (the last
// appended one) is kept.
// - otherwise, throw an exception.
public void finishWritingObject(int start, ArrayList<FieldEntry> fields) {
int dataSize = writePos - start;
int size = fields.size();
Collections.sort(fields);
int maxId = size == 0 ? 0 : fields.get(0).id;
// Check for duplicate field keys. Only need to check adjacent key because they are sorted.
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
String key = fields.get(i).key;
if (key.equals(fields.get(i - 1).key)) {
@SuppressWarnings("unchecked")
Map<String, String> parameters = Map$.MODULE$.<String, String>empty().updated("key", key);
throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
null, new QueryContext[]{}, "");
if (allowDuplicateKeys) {
int distinctPos = 0;
// Maintain a list of distinct keys in-place.
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
if (fields.get(i).id == fields.get(i - 1).id) {
// Found a duplicate key. Keep the field with a greater offset.
if (fields.get(distinctPos).offset < fields.get(i).offset) {
fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset));
}
} else {
// Found a distinct key. Add the field to the list.
++distinctPos;
fields.set(distinctPos, fields.get(i));
}
}
if (distinctPos + 1 < fields.size()) {
size = distinctPos + 1;
// Resize `fields` to `size`.
fields.subList(size, fields.size()).clear();
// Sort the fields by offsets so that we can move the value data of each field to the new
// offset without overwriting the fields after it.
fields.sort(Comparator.comparingInt(f -> f.offset));
int currentOffset = 0;
for (int i = 0; i < size; ++i) {
int oldOffset = fields.get(i).offset;
int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset);
System.arraycopy(writeBuffer, start + oldOffset,
writeBuffer, start + currentOffset, fieldSize);
fields.set(i, fields.get(i).withNewOffset(currentOffset));
currentOffset += fieldSize;
}
writePos = start + currentOffset;
// Change back to the sort order by field keys to meet the variant spec.
Collections.sort(fields);
}
} else {
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
String key = fields.get(i).key;
if (key.equals(fields.get(i - 1).key)) {
@SuppressWarnings("unchecked")
Map<String, String> parameters = Map$.MODULE$.<String, String>empty().updated("key", key);
throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
null, new QueryContext[]{}, "");
}
}
}
int dataSize = writePos - start;
boolean largeSize = size > U8_MAX;
int sizeBytes = largeSize ? U32_SIZE : 1;
int idSize = getIntegerSize(maxId);
@@ -415,6 +457,10 @@ public FieldEntry(String key, int id, int offset) {
this.offset = offset;
}

FieldEntry withNewOffset(int newOffset) {
return new FieldEntry(key, id, newOffset);
}

@Override
public int compareTo(FieldEntry other) {
return key.compareTo(other.key);
@@ -518,4 +564,5 @@ private boolean tryParseDecimal(String input) {
private final HashMap<String, Integer> dictionary = new HashMap<>();
// Store all keys in `dictionary` in the order of id.
private final ArrayList<byte[]> dictionaryKeys = new ArrayList<>();
private final boolean allowDuplicateKeys;
}
@@ -31,7 +31,10 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
*/
object VariantExpressionEvalUtils {

def parseJson(input: UTF8String, failOnError: Boolean = true): VariantVal = {
def parseJson(
input: UTF8String,
allowDuplicateKeys: Boolean = false,
failOnError: Boolean = true): VariantVal = {
def parseJsonFailure(exception: Throwable): VariantVal = {
if (failOnError) {
throw exception
Expand All @@ -40,7 +43,7 @@ object VariantExpressionEvalUtils {
}
}
try {
val v = VariantBuilder.parseJson(input.toString)
val v = VariantBuilder.parseJson(input.toString, allowDuplicateKeys)
new VariantVal(v.getValue, v.getMetadata)
} catch {
case _: VariantSizeLimitException =>
@@ -69,7 +72,9 @@

/** Cast a Spark value from `dataType` into the variant type. */
def castToVariant(input: Any, dataType: DataType): VariantVal = {
val builder = new VariantBuilder
// Enforce strict check because it is illegal for input struct/map/variant to contain duplicate
// keys.
val builder = new VariantBuilder(false)
buildVariant(builder, input, dataType)
val v = builder.result()
new VariantVal(v.getValue, v.getMetadata)
@@ -59,8 +59,11 @@ case class ParseJson(child: Expression, failOnError: Boolean = true)
VariantExpressionEvalUtils.getClass,
VariantType,
"parseJson",
Seq(child, Literal(failOnError, BooleanType)),
inputTypes :+ BooleanType,
Seq(
child,
Literal(SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS), BooleanType),
Literal(failOnError, BooleanType)),
inputTypes :+ BooleanType :+ BooleanType,
returnNullable = !failOnError)

override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
@@ -324,7 +327,7 @@ case object VariantGet {

if (dataType == VariantType) {
// Build a new variant, in order to strip off any unnecessary metadata.
val builder = new VariantBuilder
val builder = new VariantBuilder(false)
builder.appendVariant(v)
val result = builder.result()
return new VariantVal(result.getValue, result.getMetadata)
@@ -117,6 +117,8 @@ class JacksonParser(
}
}

private val variantAllowDuplicateKeys = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)

protected final def parseVariant(parser: JsonParser): VariantVal = {
// Skips `FIELD_NAME` at the beginning. This check is adapted from `parseJsonToken`, but we
// cannot directly use the function here because it also handles the `VALUE_NULL` token and
@@ -125,7 +127,7 @@
parser.nextToken()
}
try {
val v = VariantBuilder.parseJson(parser)
val v = VariantBuilder.parseJson(parser, variantAllowDuplicateKeys)
new VariantVal(v.getValue, v.getMetadata)
} catch {
case _: VariantSizeLimitException =>
@@ -4412,6 +4412,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val VARIANT_ALLOW_DUPLICATE_KEYS =
buildConf("spark.sql.variant.allowDuplicateKeys")
.internal()
.doc("When set to false, parsing variant from JSON will throw an error if there are " +
"duplicate keys in the input JSON object. When set to true, the parser will keep the " +
"last occurrence of all fields with the same key.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
.internal()
@@ -24,12 +24,14 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class VariantExpressionEvalUtilsSuite extends SparkFunSuite {

test("parseJson type coercion") {
def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte],
allowDuplicateKeys: Boolean = false): Unit = {
// parse_json
val actual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
val actual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
allowDuplicateKeys)
// try_parse_json
val tryActual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
failOnError = false)
allowDuplicateKeys, failOnError = false)
val expected = new VariantVal(expectedValue, expectedMetadata)
assert(actual === expected && tryActual === expected)
}
@@ -89,6 +91,13 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
/* offset list */ 0, 2, 4, 6,
/* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
check("""{"a": 1, "b": 2, "c": "3", "a": 4}""", Array(objectHeader(false, 1, 1),
/* size */ 3,
/* id list */ 0, 1, 2,
/* offset list */ 4, 0, 2, 6,
/* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', primitiveHeader(INT1), 4),
Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'),
allowDuplicateKeys = true)
check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
/* size */ 3,
/* id list */ 2, 1, 0,
@@ -109,10 +118,11 @@
test("parseJson negative") {
def checkException(json: String, errorClass: String, parameters: Map[String, String]): Unit = {
val try_parse_json_output = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
failOnError = false)
allowDuplicateKeys = false, failOnError = false)
checkError(
exception = intercept[SparkThrowable] {
VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json),
allowDuplicateKeys = false)
},
errorClass = errorClass,
parameters = parameters
@@ -1,2 +1,2 @@
Project [static_invoke(VariantExpressionEvalUtils.isVariantNull(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)))) AS is_variant_null(parse_json(g))#0]
Project [static_invoke(VariantExpressionEvalUtils.isVariantNull(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)))) AS is_variant_null(parse_json(g))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)) AS parse_json(g)#0]
Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)) AS parse_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [static_invoke(SchemaOfVariant.schemaOfVariant(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)))) AS schema_of_variant(parse_json(g))#0]
Project [static_invoke(SchemaOfVariant.schemaOfVariant(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)))) AS schema_of_variant(parse_json(g))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Aggregate [schema_of_variant_agg(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), 0, 0) AS schema_of_variant_agg(parse_json(g))#0]
Aggregate [schema_of_variant_agg(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), 0, 0) AS schema_of_variant_agg(parse_json(g))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false)) AS try_parse_json(g)#0]
Project [static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, false)) AS try_parse_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [try_variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0]
Project [try_variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), $, IntegerType, false, Some(America/Los_Angeles)) AS try_variant_get(parse_json(g), $)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, true)), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0]
Project [variant_get(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true)), $, IntegerType, true, Some(America/Los_Angeles)) AS variant_get(parse_json(g), $)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarArray
@@ -57,6 +58,12 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
)
// scalastyle:on nonascii
check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") {
check(
"""{"c": [], "b": 0, "a": null, "a": {"x": 0, "x": 1}, "b": 1, "b": 2, "c": [3]}""",
"""{"a":{"x":1},"b":2,"c":[3]}"""
)
}
}

test("from_json/to_json round-trip") {
@@ -146,7 +153,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
val variantDF = df.select(parse_json(col("v")))
val plan = variantDF.queryExecution.executedPlan
assert(plan.isInstanceOf[WholeStageCodegenExec])
val v = VariantBuilder.parseJson("""{"a":1}""")
val v = VariantBuilder.parseJson("""{"a":1}""", false)
val expected = new VariantVal(v.getValue, v.getMetadata)
checkAnswer(variantDF, Seq(Row(expected)))
}
