diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 1fcca7e4c3984..b016a29a86be1 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -79,25 +79,34 @@ jobs: id: set-outputs run: | if [ -z "${{ inputs.jobs }}" ]; then - pyspark=true; sparkr=true; tpcds=true; docker=true; pyspark_modules=`cd dev && python -c "import sparktestsupport.modules as m; print(','.join(m.name for m in m.all_modules if m.name.startswith('pyspark')))"` pyspark=`./dev/is-changed.py -m $pyspark_modules` - sparkr=`./dev/is-changed.py -m sparkr` - tpcds=`./dev/is-changed.py -m sql` - docker=`./dev/is-changed.py -m docker-integration-tests` - # 'build', 'scala-213', and 'java-11-17' are always true for now. - # It does not save significant time and most of PRs trigger the build. + if [[ "${{ github.repository }}" != 'apache/spark' ]]; then + pandas=$pyspark + kubernetes=`./dev/is-changed.py -m kubernetes` + sparkr=`./dev/is-changed.py -m sparkr` + tpcds=`./dev/is-changed.py -m sql` + docker=`./dev/is-changed.py -m docker-integration-tests` + else + pandas=false + kubernetes=false + sparkr=false + tpcds=false + docker=false + fi + build=`./dev/is-changed.py -m "core,unsafe,kvstore,avro,utils,network-common,network-shuffle,repl,launcher,examples,sketch,graphx,catalyst,hive-thriftserver,streaming,sql-kafka-0-10,streaming-kafka-0-10,mllib-local,mllib,yarn,mesos,kubernetes,hadoop-cloud,spark-ganglia-lgpl,sql,hive"` precondition=" { - \"build\": \"true\", + \"build\": \"$build\", \"pyspark\": \"$pyspark\", + \"pyspark-pandas\": \"$pandas\", \"sparkr\": \"$sparkr\", \"tpcds-1g\": \"$tpcds\", \"docker-integration-tests\": \"$docker\", - \"scala-213\": \"true\", - \"java-11-17\": \"true\", + \"scala-213\": \"$build\", + \"java-11-17\": \"$build\", \"lint\" : \"true\", - \"k8s-integration-tests\" : \"true\", + \"k8s-integration-tests\" : \"$kubernetes\", \"breaking-changes-buf\" : \"true\", }" echo $precondition # For debugging @@ -204,6 +213,8 @@ jobs: HIVE_PROFILE: ${{ matrix.hive }} GITHUB_PREV_SHA: ${{ github.event.before }} SPARK_LOCAL_IP: localhost + SKIP_UNIDOC: true + SKIP_MIMA: true SKIP_PACKAGING: true steps: - name: Checkout Spark repository @@ -256,7 +267,7 @@ jobs: - name: Install Python packages (Python 3.8) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) run: | - python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3' + python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3' python3.8 -m pip list # Run the tests. - name: Run tests @@ -360,6 +371,14 @@ jobs: pyspark-pandas-connect - >- pyspark-pandas-slow-connect + exclude: + # Always run if pyspark-pandas == 'true', even infra-image is skip (such as non-master job) + # In practice, the build will run in individual PR, but not against the individual commit + # in Apache Spark repository. 
+ - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-slow' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-connect' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-slow-connect' }} env: MODULES_TO_TEST: ${{ matrix.modules }} HADOOP_PROFILE: ${{ inputs.hadoop }} @@ -407,6 +426,8 @@ jobs: key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | pyspark-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -504,6 +525,8 @@ jobs: key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | sparkr-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ inputs.java }} uses: actions/setup-java@v3 with: @@ -612,6 +635,8 @@ jobs: key: docs-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | docs-maven- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java 8 uses: actions/setup-java@v3 with: @@ -621,6 +646,8 @@ jobs: run: ./dev/check-license - name: Dependencies test run: ./dev/test-dependencies.sh + - name: MIMA test + run: ./dev/mima - name: Scala linter run: ./dev/lint-scala - name: Java linter @@ -672,16 +699,16 @@ jobs: # See also https://issues.apache.org/jira/browse/SPARK-35375. # Pin the MarkupSafe to 2.0.1 to resolve the CI error. # See also https://issues.apache.org/jira/browse/SPARK-38279. - python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' + python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme 'sphinx-copybutton==0.5.2' 'nbsphinx==0.9.3' numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' 'nest-asyncio==1.5.8' 'rpds-py==0.16.2' 'alabaster==0.7.13' python3.9 -m pip install ipython_genutils # See SPARK-38517 - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas 'plotly>=4.8' python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 apt-get update -y apt-get install -y ruby ruby-dev Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'markdown', 'e1071', 'roxygen2', 'ggplot2', 'mvtnorm', 'statmod'), repos='https://cloud.r-project.org/')" Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" - gem install bundler + gem install bundler -v 2.4.22 cd docs bundle install - name: R linter @@ -1010,9 +1037,7 @@ jobs: - name: start minikube run: | # See more in "Installation" https://minikube.sigs.k8s.io/docs/start/ - # curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 - # TODO(SPARK-44495): Resume to use the latest minikube for k8s-integration-tests. 
- curl -LO https://storage.googleapis.com/minikube/releases/v1.30.1/minikube-linux-amd64 + curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 sudo install minikube-linux-amd64 /usr/local/bin/minikube # Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic minikube start --cpus 2 --memory 6144 @@ -1030,7 +1055,7 @@ jobs: kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.7.0/installer/volcano-development.yaml || true eval $(minikube docker-env) - build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" + build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files if: failure() uses: actions/upload-artifact@v3 diff --git a/.gitignore b/.gitignore index 064b502175b79..06c6660900d66 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ .scala_dependencies .settings .vscode +artifacts/ /lib/ R-unit-tests.log R/unit-tests.out @@ -50,6 +51,7 @@ dev/create-release/*final dev/create-release/*txt dev/pr-deps/ dist/ +docs/_generated/ docs/_site/ docs/api docs/.local_ruby_bundle diff --git a/LICENSE b/LICENSE index 1735d3208f2e2..74686d7ffa388 100644 --- a/LICENSE +++ b/LICENSE @@ -218,11 +218,6 @@ docs/js/vendor/bootstrap.js connector/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java -Python Software Foundation License ----------------------------------- - -python/docs/source/_static/copybutton.js - BSD 3-Clause ------------ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 1c093a4a98046..8657755b8d0ea 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 3.5.0 +Version: 3.5.4 Title: R Front End for 'Apache Spark' Description: Provides an R Front end for 'Apache Spark' . Authors@R: diff --git a/assembly/pom.xml b/assembly/pom.xml index a0aca22eab91d..3367c1629c578 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml @@ -159,6 +159,12 @@ org.apache.spark spark-connect_${scala.binary.version} ${project.version} + + + org.apache.spark + spark-connect-common_${scala.binary.version} + + org.apache.spark diff --git a/binder/Dockerfile b/binder/Dockerfile new file mode 100644 index 0000000000000..6e3dd9155fb7a --- /dev/null +++ b/binder/Dockerfile @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM python:3.10-slim +# install the notebook package +RUN pip install --no-cache notebook jupyterlab + +# create user with a home directory +ARG NB_USER +ARG NB_UID +ENV USER ${NB_USER} +ENV HOME /home/${NB_USER} + +RUN adduser --disabled-password \ + --gecos "Default user" \ + --uid ${NB_UID} \ + ${NB_USER} +WORKDIR ${HOME} +USER ${USER} + +# Make sure the contents of our repo are in ${HOME} +COPY . ${HOME} +USER root +RUN chown -R ${NB_UID} ${HOME} +RUN apt-get update && apt-get install -y openjdk-17-jre git coreutils +USER ${NB_USER} + +RUN binder/postBuild + diff --git a/binder/apt.txt b/binder/apt.txt deleted file mode 100644 index 3d86667d4b910..0000000000000 --- a/binder/apt.txt +++ /dev/null @@ -1,2 +0,0 @@ -openjdk-8-jre -git diff --git a/binder/postBuild b/binder/postBuild old mode 100644 new mode 100755 index 70ae23b393707..c17816d4a5009 --- a/binder/postBuild +++ b/binder/postBuild @@ -20,8 +20,13 @@ # This file is used for Binder integration to install PySpark available in # Jupyter notebook. +# SPARK-45706: Should fail fast. Otherwise, the Binder image is successfully +# built, and it cannot be rebuilt. +set -o pipefail +set -e + VERSION=$(python -c "exec(open('python/pyspark/version.py').read()); print(__version__)") -TAG=$(git describe --tags --exact-match 2>/dev/null) +TAG=$(git describe --tags --exact-match 2> /dev/null || true) # If a commit is tagged, exactly specified version of pyspark should be installed to avoid # a kind of accident that an old version of pyspark is installed in the live notebook environment. @@ -33,9 +38,9 @@ else fi if [[ ! $VERSION < "3.4.0" ]]; then - pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark,connect]$SPECIFIER$VERSION" + pip install plotly "pandas<2.0.0" "numpy>=1.15,<2" "pyspark[sql,ml,mllib,pandas_on_spark,connect]$SPECIFIER$VERSION" else - pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION" + pip install plotly "pandas<2.0.0" "numpy>=1.15,<2" "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION" fi # Set 'PYARROW_IGNORE_TIMEZONE' to surpress warnings from PyArrow. 
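For context on the version pinning that the postBuild comments describe: the branch that sets SPECIFIER is elided by the hunk above, so the sketch below is a hypothetical illustration, not part of the patch; the "==" / "<=" values and the example VERSION are assumptions.

    #!/usr/bin/env bash
    # Hypothetical sketch only (not part of the patch above).
    VERSION=3.5.4                                            # assumed example value from python/pyspark/version.py
    TAG=$(git describe --tags --exact-match 2> /dev/null || true)

    if [[ -n "$TAG" ]]; then
      SPECIFIER="=="    # assumed: tagged release installs exactly the matching PyPI version
    else
      SPECIFIER="<="    # assumed: untagged dev build falls back to the newest release not newer than VERSION
    fi

    # With the new numpy pin, the resulting install resolves to something like:
    pip install plotly "pandas<2.0.0" "numpy>=1.15,<2" \
      "pyspark[sql,ml,mllib,pandas_on_spark,connect]${SPECIFIER}${VERSION}"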
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index ce180f49ff128..014ff5bbaf209 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml @@ -66,6 +66,11 @@ commons-io test + + org.apache.commons + commons-lang3 + test + org.apache.logging.log4j diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8da48076a43aa..ed2352fd1276e 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4a0a156699852..40825e06b82fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -325,7 +325,10 @@ public TransportResponseHandler getHandler() { @Override public void close() { - // close is a local operation and should finish with milliseconds; timeout just to be safe + // Mark the connection as timed out, so we do not return a connection that's being closed + // from the TransportClientFactory if closing takes some time (e.g. with SSL) + this.timedOut = true; + // close should not take this long; use a timeout just to be safe channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 078d9ceb317b8..ee558bce7dab9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -41,16 +41,21 @@ * Exchange, using a pre-shared key to derive an AES-GCM key encrypting key. */ class AuthEngine implements Closeable { + public static final byte[] DERIVED_KEY_INFO = "derivedKey".getBytes(UTF_8); public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8); public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8); private static final String MAC_ALGORITHM = "HMACSHA256"; + private static final String LEGACY_CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private static final String CIPHER_ALGORITHM = "AES/GCM/NoPadding"; private static final int AES_GCM_KEY_SIZE_BYTES = 16; private static final byte[] EMPTY_TRANSCRIPT = new byte[0]; + private static final int UNSAFE_SKIP_HKDF_VERSION = 1; private final String appId; private final byte[] preSharedSecret; private final TransportConf conf; private final Properties cryptoConf; + private final boolean unsafeSkipFinalHkdf; private byte[] clientPrivateKey; private TransportCipher sessionCipher; @@ -62,6 +67,9 @@ class AuthEngine implements Closeable { this.preSharedSecret = preSharedSecret.getBytes(UTF_8); this.conf = conf; this.cryptoConf = conf.cryptoConf(); + // This is for backward compatibility with version 1.0 of this protocol, + // which did not perform a final HKDF round. 
+ this.unsafeSkipFinalHkdf = conf.authEngineVersion() == UNSAFE_SKIP_HKDF_VERSION; } @VisibleForTesting @@ -201,6 +209,13 @@ private TransportCipher generateTransportCipher( byte[] sharedSecret, boolean isClient, byte[] transcript) throws GeneralSecurityException { + byte[] derivedKey = unsafeSkipFinalHkdf ? sharedSecret : // This is for backwards compatibility + Hkdf.computeHkdf( + MAC_ALGORITHM, + sharedSecret, + transcript, + DERIVED_KEY_INFO, + AES_GCM_KEY_SIZE_BYTES); byte[] clientIv = Hkdf.computeHkdf( MAC_ALGORITHM, sharedSecret, @@ -213,13 +228,20 @@ private TransportCipher generateTransportCipher( transcript, // Passing this as the HKDF salt OUTPUT_IV_INFO, // This is the HKDF info field used to differentiate IV values AES_GCM_KEY_SIZE_BYTES); - SecretKeySpec sessionKey = new SecretKeySpec(sharedSecret, "AES"); - return new TransportCipher( - cryptoConf, - conf.cipherTransformation(), - sessionKey, - isClient ? clientIv : serverIv, // If it's the client, use the client IV first - isClient ? serverIv : clientIv); + SecretKeySpec sessionKey = new SecretKeySpec(derivedKey, "AES"); + if (LEGACY_CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new CtrTransportCipher( + cryptoConf, + sessionKey, + isClient ? clientIv : serverIv, // If it's the client, use the client IV first + isClient ? serverIv : clientIv); + } else if (CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new GcmTransportCipher(sessionKey); + } else { + throw new IllegalArgumentException( + String.format("Unsupported cipher mode: %s. %s and %s are supported.", + conf.cipherTransformation(), CIPHER_ALGORITHM, LEGACY_CIPHER_ALGORITHM)); + } } private byte[] getTranscript(AuthMessage... encryptedPublicKeys) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java new file mode 100644 index 0000000000000..85b893751b39c --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; + +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; + +/** + * Cipher for encryption and decryption. + */ +public class CtrTransportCipher implements TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "CtrTransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "CtrTransportDecryption"; + @VisibleForTesting + static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private static final String CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public CtrTransportCipher( + Properties conf, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(key); + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; + } + + @VisibleForTesting + CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(outIv)); + } + + @VisibleForTesting + CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(inIv)); + } + + /** + * Add handlers to channel. 
+ * + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); + } + + @VisibleForTesting + static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteEncChannel; + private final CryptoOutputStream cos; + private final ByteArrayWritableChannel byteRawChannel; + private boolean isCipherValid; + + EncryptionHandler(CtrTransportCipher cipher) throws IOException { + byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteEncChannel); + byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + isCipherValid = true; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(createEncryptedMessage(msg), promise); + } + + @VisibleForTesting + EncryptedMessage createEncryptedMessage(Object msg) { + return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + if (isCipherValid) { + cos.close(); + } + } finally { + super.close(ctx, promise); + } + } + + /** + * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } + } + + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; + + DecryptionHandler(CtrTransportCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + isCipherValid = true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf buffer = (ByteBuf) data; + + try { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } + byte[] decryptedData = new byte[buffer.readableBytes()]; + byteChannel.feedData(buffer); + + int offset = 0; + while (offset < decryptedData.length) { + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } finally { + buffer.release(); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // We do the closing of the stream / channel in handlerRemoved(...) 
as + // this method will be called in all cases: + // + // - when the Channel becomes inactive + // - when the handler is removed from the ChannelPipeline + try { + if (isCipherValid) { + cis.close(); + } + } finally { + super.handlerRemoved(ctx); + } + } + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractFileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; + private final long count; + private long transferred; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. + private final ByteArrayWritableChannel byteEncChannel; + private final ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel byteEncChannel, + ByteArrayWritableChannel byteRawChannel) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.cos = cos; + this.byteEncChannel = byteEncChannel; + this.byteRawChannel = byteRawChannel; + this.count = isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long count() { + return count; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (region != null) { + region.touch(o); + } + if (buf != null) { + buf.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (region != null) { + region.retain(increment); + } + if (buf != null) { + buf.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transferred(), "Invalid position."); + + if (transferred == count) { + return 0; + } + + long totalBytesWritten = 0L; + do { + if (currentEncrypted == null) { + encryptMore(); + } + + long remaining = currentEncrypted.remaining(); + if (remaining == 0) { + // Just for safety to avoid endless loop. It usually won't happen, but since the + // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for + // safety. 
+ currentEncrypted = null; + byteEncChannel.reset(); + return totalBytesWritten; + } + + long bytesWritten = target.write(currentEncrypted); + totalBytesWritten += bytesWritten; + transferred += bytesWritten; + if (bytesWritten < remaining) { + // break as the underlying buffer in "target" is full + break; + } + currentEncrypted = null; + byteEncChannel.reset(); + } while (transferred < count); + + return totalBytesWritten; + } + + private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transferred()); + } + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java new file mode 100644 index 0000000000000..d3f1bf490d3a3 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Longs; +import com.google.crypto.tink.subtle.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.ReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteBufferWriteableChannel; + +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; + +public class GcmTransportCipher implements TransportCipher { + private static final String HKDF_ALG = "HmacSha256"; + private static final int LENGTH_HEADER_BYTES = 8; + @VisibleForTesting + static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB + private final SecretKeySpec aesKey; + + public GcmTransportCipher(SecretKeySpec aesKey) { + this.aesKey = aesKey; + } + + AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException { + return new AesGcmHkdfStreaming( + aesKey.getEncoded(), + HKDF_ALG, + aesKey.getEncoded().length, + CIPHERTEXT_BUFFER_SIZE, + 0); + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(aesKey); + } + + @VisibleForTesting + EncryptionHandler getEncryptionHandler() throws GeneralSecurityException { + return new EncryptionHandler(); + } + + @VisibleForTesting + DecryptionHandler getDecryptionHandler() throws GeneralSecurityException { + return new DecryptionHandler(); + } + + public void addToChannel(Channel ch) throws GeneralSecurityException { + ch.pipeline() + .addFirst("GcmTransportEncryption", getEncryptionHandler()) + .addFirst("GcmTransportDecryption", getDecryptionHandler()); + } + + @VisibleForTesting + class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + + EncryptionHandler() throws InvalidAlgorithmParameterException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( + aesGcmHkdfStreaming, + msg, + plaintextBuffer, + ciphertextBuffer); + ctx.write(encryptedMessage, promise); + } + } + + static class GcmEncryptedMessage extends AbstractFileRegion { + private final Object plaintextMessage; + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final ByteBuffer headerByteBuffer; + private final long bytesToRead; + private long bytesRead = 0; + private final StreamSegmentEncrypter encrypter; + private long transferred = 0; + private final long encryptedCount; + + GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, + Object plaintextMessage, + ByteBuffer plaintextBuffer, + ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Preconditions.checkArgument( + 
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, + "Unrecognized message type: %s", plaintextMessage.getClass().getName()); + this.plaintextMessage = plaintextMessage; + this.plaintextBuffer = plaintextBuffer; + this.ciphertextBuffer = ciphertextBuffer; + // If the ciphertext buffer cannot be fully written the target, transferTo may + // return with it containing some unwritten data. The initial call we'll explicitly + // set its limit to 0 to indicate the first call to transferTo. + ((Buffer) this.ciphertextBuffer).limit(0); + this.bytesToRead = getReadableBytes(); + this.encryptedCount = + LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); + byte[] lengthAad = Longs.toByteArray(encryptedCount); + this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad); + this.headerByteBuffer = createHeaderByteBuffer(); + } + + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + private ByteBuffer createHeaderByteBuffer() { + ByteBuffer encrypterHeader = encrypter.getHeader(); + ByteBuffer output = ByteBuffer + .allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES) + .putLong(encryptedCount) + .put(encrypterHeader); + ((Buffer) output).flip(); + return output; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public long count() { + return encryptedCount; + } + + @Override + public GcmEncryptedMessage touch(Object o) { + super.touch(o); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.touch(o); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.touch(o); + } + return this; + } + + @Override + public GcmEncryptedMessage retain(int increment) { + super.retain(increment); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.retain(increment); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.release(decrement); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + int transferredThisCall = 0; + // If the header has is not empty, try to write it out to the target. + if (headerByteBuffer.hasRemaining()) { + int written = target.write(headerByteBuffer); + transferredThisCall += written; + this.transferred += written; + if (headerByteBuffer.hasRemaining()) { + return written; + } + } + // If the ciphertext buffer is not empty, try to write it to the target. 
+ if (ciphertextBuffer.hasRemaining()) { + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + return transferredThisCall; + } + } + while (bytesRead < bytesToRead) { + long readableBytes = getReadableBytes(); + int readLimit = + (int) Math.min(readableBytes, plaintextBuffer.remaining()); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + Preconditions.checkState(0 == plaintextBuffer.position()); + ((Buffer) plaintextBuffer).limit(readLimit); + byteBuf.readBytes(plaintextBuffer); + Preconditions.checkState(readLimit == plaintextBuffer.position()); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + ByteBufferWriteableChannel plaintextChannel = + new ByteBufferWriteableChannel(plaintextBuffer); + long plaintextRead = + fileRegion.transferTo(plaintextChannel, fileRegion.transferred()); + if (plaintextRead < readLimit) { + // If we do not read a full plaintext buffer or all the available + // readable bytes, return what was transferred this call. + return transferredThisCall; + } + } + boolean lastSegment = getReadableBytes() == 0; + ((Buffer) plaintextBuffer).flip(); + bytesRead += plaintextBuffer.remaining(); + ((Buffer) ciphertextBuffer).clear(); + try { + encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer); + } catch (GeneralSecurityException e) { + throw new IllegalStateException("GeneralSecurityException from encrypter", e); + } + ((Buffer) plaintextBuffer).clear(); + ((Buffer) ciphertextBuffer).flip(); + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + // In this case, upon calling transferTo again, it will try to write the + // remaining ciphertext buffer in the conditional before this loop. 
+ return transferredThisCall; + } + } + return transferredThisCall; + } + + private long getReadableBytes() { + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + return byteBuf.readableBytes(); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + return fileRegion.count() - fileRegion.transferred(); + } else { + throw new IllegalArgumentException("Unsupported message type: " + + plaintextMessage.getClass().getName()); + } + } + + @Override + protected void deallocate() { + if (plaintextMessage instanceof ReferenceCounted) { + ((ReferenceCounted) plaintextMessage).release(); + } + plaintextBuffer.clear(); + ciphertextBuffer.clear(); + } + } + + @VisibleForTesting + class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final ByteBuffer expectedLengthBuffer; + private final ByteBuffer headerBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + private final StreamSegmentDecrypter decrypter; + private final int plaintextSegmentSize; + private boolean decrypterInit = false; + private boolean completed = false; + private int segmentNumber = 0; + private long expectedLength = -1; + private long ciphertextRead = 0; + + DecryptionHandler() throws GeneralSecurityException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); + headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); + } + + private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + if (expectedLength < 0) { + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + if (expectedLengthBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the expected length. + return false; + } + ((Buffer) expectedLengthBuffer).flip(); + expectedLength = expectedLengthBuffer.getLong(); + if (expectedLength < 0) { + throw new IllegalStateException("Invalid expected ciphertext length."); + } + ciphertextRead += LENGTH_HEADER_BYTES; + } + return true; + } + + private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + throws GeneralSecurityException { + // Check if the ciphertext header has been read. This contains + // the IV and other internal metadata. + if (!decrypterInit) { + ciphertextNettyBuf.readBytes(headerBuffer); + if (headerBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the header. + return false; + } + ((Buffer) headerBuffer).flip(); + byte[] lengthAad = Longs.toByteArray(expectedLength); + decrypter.init(headerBuffer, lengthAad); + decrypterInit = true; + ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + if (expectedLength == ciphertextRead) { + // If the expected length is just the header, the ciphertext is 0 length. 
+ completed = true; + } + } + return true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) + throws GeneralSecurityException { + Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf, + "Unrecognized message type: %s", + ciphertextMessage.getClass().getName()); + ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + try { + if (!initalizeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + return; + } + if (!initalizeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. + return; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Integer.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Integer.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ((Buffer) ciphertextBuffer).limit( + ((Buffer) ciphertextBuffer).position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ((Buffer) ciphertextBuffer).flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ((Buffer) ciphertextBuffer).clear(); + ((Buffer) plaintextBuffer).flip(); + ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + } finally { + ciphertextNettyBuf.release(); + } + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md index 78e7459b9995d..5d3584d80462c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -1,6 +1,9 @@ -Forward Secure Auth Protocol +Forward Secure Auth Protocol v2.0 ============================================== +Summary +------- + This file describes a forward secure authentication protocol which may be used by Spark. This protocol is essentially ephemeral Diffie-Hellman key exchange using Curve25519, referred to as X25519. @@ -77,6 +80,7 @@ Now that the server has the client's ephemeral public key, it can generate its o keypair and compute a shared secret. 
sharedSecret = X25519.computeSharedSecret(clientPublicKey, serverKeyPair.privateKey()) + derivedKey = HKDF(sharedSecret, salt=transcript, info="deriveKey") With the shared secret, the server will also generate two initialization vectors to be used for inbound and outbound streams. These IVs are not secret and will be bound to the preceding protocol @@ -99,3 +103,14 @@ sessions. It would, however, allow impersonation of future sessions. In the event of a pre-shared key compromise, messages would still be confidential from a passive observer. Only active adversaries spoofing a session would be able to recover plaintext. +Security Changes & Compatibility +------------- + +The original version of this protocol, retroactively called v1.0, did not apply an HKDF to `sharedSecret` to derive +a key (i.e. `derivedKey`) and was directly using the encoded X coordinate as key material. This is atypical and +standard practice is to pass that shared coordinate through an HKDF. The latest version adds this additional +HKDF to derive `derivedKey`. + +Consequently, Apache Spark instances using v1.0 of this protocol will not negotiate the same key as +instances using v2.0 and will be **unable to send encrypted RPCs** across incompatible versions. For this reason, v1.0 +remains the default to preserve backward-compatibility. \ No newline at end of file diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b507f911fe11a..355c552720185 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,362 +17,32 @@ package org.apache.spark.network.crypto; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; -import java.util.Properties; -import javax.crypto.spec.SecretKeySpec; -import javax.crypto.spec.IvParameterSpec; - import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.*; -import org.apache.commons.crypto.stream.CryptoInputStream; -import org.apache.commons.crypto.stream.CryptoOutputStream; - -import org.apache.spark.network.util.AbstractFileRegion; -import org.apache.spark.network.util.ByteArrayReadableChannel; -import org.apache.spark.network.util.ByteArrayWritableChannel; - -/** - * Cipher for encryption and decryption. - */ -public class TransportCipher { - @VisibleForTesting - static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; - private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; - @VisibleForTesting - static final int STREAM_BUFFER_SIZE = 1024 * 32; - - private final Properties conf; - private final String cipher; - private final SecretKeySpec key; - private final byte[] inIv; - private final byte[] outIv; - - public TransportCipher( - Properties conf, - String cipher, - SecretKeySpec key, - byte[] inIv, - byte[] outIv) { - this.conf = conf; - this.cipher = cipher; - this.key = key; - this.inIv = inIv; - this.outIv = outIv; - } - - public String getCipherTransformation() { - return cipher; - } - - @VisibleForTesting - SecretKeySpec getKey() { - return key; - } - - /** The IV for the input channel (i.e. output channel of the remote side). 
*/ - public byte[] getInputIv() { - return inIv; - } - - /** The IV for the output channel (i.e. input channel of the remote side). */ - public byte[] getOutputIv() { - return outIv; - } - - @VisibleForTesting - CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { - return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); - } - - @VisibleForTesting - CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { - return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); - } - - /** - * Add handlers to channel. - * - * @param ch the channel for adding handlers - * @throws IOException - */ - public void addToChannel(Channel ch) throws IOException { - ch.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) - .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); - } - - @VisibleForTesting - static class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteArrayWritableChannel byteEncChannel; - private final CryptoOutputStream cos; - private final ByteArrayWritableChannel byteRawChannel; - private boolean isCipherValid; - - EncryptionHandler(TransportCipher cipher) throws IOException { - byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - cos = cipher.createOutputStream(byteEncChannel); - byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - isCipherValid = true; - } +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; +import io.netty.channel.Channel; - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - ctx.write(createEncryptedMessage(msg), promise); - } - - @VisibleForTesting - EncryptedMessage createEncryptedMessage(Object msg) { - return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); - } +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - try { - if (isCipherValid) { - cos.close(); - } - } finally { - super.close(ctx, promise); - } - } +interface TransportCipher { + String getKeyId() throws GeneralSecurityException; + void addToChannel(Channel channel) throws IOException, GeneralSecurityException; +} - /** - * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher - * after an error occurs. +class TransportCipherUtil { + /* + * This method is used for testing to verify key derivation. 
*/ - void reportError() { - this.isCipherValid = false; - } - - boolean isCipherValid() { - return isCipherValid; - } - } - - private static class DecryptionHandler extends ChannelInboundHandlerAdapter { - private final CryptoInputStream cis; - private final ByteArrayReadableChannel byteChannel; - private boolean isCipherValid; - - DecryptionHandler(TransportCipher cipher) throws IOException { - byteChannel = new ByteArrayReadableChannel(); - cis = cipher.createInputStream(byteChannel); - isCipherValid = true; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf buffer = (ByteBuf) data; - - try { - if (!isCipherValid) { - throw new IOException("Cipher is in invalid state."); - } - byte[] decryptedData = new byte[buffer.readableBytes()]; - byteChannel.feedData(buffer); - - int offset = 0; - while (offset < decryptedData.length) { - // SPARK-25535: workaround for CRYPTO-141. - try { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); - } catch (InternalError ie) { - isCipherValid = false; - throw ie; - } - } - - ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); - } finally { - buffer.release(); - } - } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - // We do the closing of the stream / channel in handlerRemoved(...) as - // this method will be called in all cases: - // - // - when the Channel becomes inactive - // - when the handler is removed from the ChannelPipeline - try { - if (isCipherValid) { - cis.close(); - } - } finally { - super.handlerRemoved(ctx); - } - } - } - - @VisibleForTesting - static class EncryptedMessage extends AbstractFileRegion { - private final boolean isByteBuf; - private final ByteBuf buf; - private final FileRegion region; - private final CryptoOutputStream cos; - private final EncryptionHandler handler; - private final long count; - private long transferred; - - // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has - // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data - // from upper handler, another is used to store encrypted data. - private final ByteArrayWritableChannel byteEncChannel; - private final ByteArrayWritableChannel byteRawChannel; - - private ByteBuffer currentEncrypted; - - EncryptedMessage( - EncryptionHandler handler, - CryptoOutputStream cos, - Object msg, - ByteArrayWritableChannel byteEncChannel, - ByteArrayWritableChannel byteRawChannel) { - Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, - "Unrecognized message type: %s", msg.getClass().getName()); - this.handler = handler; - this.isByteBuf = msg instanceof ByteBuf; - this.buf = isByteBuf ? (ByteBuf) msg : null; - this.region = isByteBuf ? null : (FileRegion) msg; - this.transferred = 0; - this.cos = cos; - this.byteEncChannel = byteEncChannel; - this.byteRawChannel = byteRawChannel; - this.count = isByteBuf ? 
buf.readableBytes() : region.count(); - } - - @Override - public long count() { - return count; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return transferred; - } - - @Override - public EncryptedMessage touch(Object o) { - super.touch(o); - if (region != null) { - region.touch(o); - } - if (buf != null) { - buf.touch(o); - } - return this; - } - - @Override - public EncryptedMessage retain(int increment) { - super.retain(increment); - if (region != null) { - region.retain(increment); - } - if (buf != null) { - buf.retain(increment); - } - return this; - } - - @Override - public boolean release(int decrement) { - if (region != null) { - region.release(decrement); - } - if (buf != null) { - buf.release(decrement); - } - return super.release(decrement); - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transferred(), "Invalid position."); - - if (transferred == count) { - return 0; - } - - long totalBytesWritten = 0L; - do { - if (currentEncrypted == null) { - encryptMore(); - } - - long remaining = currentEncrypted.remaining(); - if (remaining == 0) { - // Just for safety to avoid endless loop. It usually won't happen, but since the - // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for - // safety. - currentEncrypted = null; - byteEncChannel.reset(); - return totalBytesWritten; - } - - long bytesWritten = target.write(currentEncrypted); - totalBytesWritten += bytesWritten; - transferred += bytesWritten; - if (bytesWritten < remaining) { - // break as the underlying buffer in "target" is full - break; - } - currentEncrypted = null; - byteEncChannel.reset(); - } while (transferred < count); - - return totalBytesWritten; - } - - private void encryptMore() throws IOException { - if (!handler.isCipherValid()) { - throw new IOException("Cipher is in invalid state."); - } - byteRawChannel.reset(); - - if (isByteBuf) { - int copied = byteRawChannel.write(buf.nioBuffer()); - buf.skipBytes(copied); - } else { - region.transferTo(byteRawChannel, region.transferred()); - } - - try { - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); - } catch (InternalError ie) { - handler.reportError(); - throw ie; - } - - currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), - 0, byteEncChannel.length()); - } - - @Override - protected void deallocate() { - byteRawChannel.reset(); - byteEncChannel.reset(); - if (region != null) { - region.release(); - } - if (buf != null) { - buf.release(); - } + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); } - } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java new file mode 100644 index 0000000000000..d49f46afa7ec4 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.WritableByteChannel; + +public class ByteBufferWriteableChannel implements WritableByteChannel { + private final ByteBuffer destination; + private boolean open; + + public ByteBufferWriteableChannel(ByteBuffer destination) { + this.destination = destination; + this.open = true; + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + int bytesToWrite = Math.min(src.remaining(), destination.remaining()); + // Destination buffer is full + if (bytesToWrite == 0) { + return 0; + } + ByteBuffer temp = src.slice(); + ((Buffer) temp).limit(bytesToWrite); + destination.put(temp); + ((Buffer) src).position(((Buffer) src).position() + bytesToWrite); + return bytesToWrite; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 45e9994be7225..e4966b32fb454 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -212,6 +212,15 @@ public boolean encryptionEnabled() { return conf.getBoolean("spark.network.crypto.enabled", false); } + /** + * Version number to be used by the AuthEngine key agreement protocol. Valid values are 1 or 2. + * The default version is 1 for backward compatibility. Version 2 is recommended for stronger + * security properties. + */ + public int authEngineVersion() { + return conf.getInt("spark.network.crypto.authEngineVersion", 1); + } + /** * The cipher transformation to use for encrypting session data. 
*/ diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index c6029a70bd61d..ad737e5332dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -18,68 +18,76 @@ package org.apache.spark.network.crypto; import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; -import java.util.Random; +import java.util.Map; +import com.google.common.collect.ImmutableMap; import com.google.crypto.tink.subtle.Hex; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.*; + import static org.junit.Assert.*; -import org.junit.BeforeClass; import org.junit.Test; -import static org.mockito.Mockito.*; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -public class AuthEngineSuite { - - private static final String clientPrivate = - "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; - private static final String clientChallengeHex = - "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + - "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + - "65f8c426e18ff380f6"; - private static final String serverResponseHex = - "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + - "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + - "08ecad08b46b5ee3ff"; - private static final String sharedKey = - "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; - private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; - private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; - private static TransportConf conf; - @BeforeClass - public static void setUp() { - conf = new TransportConf("rpc", MapConfigProvider.EMPTY); +abstract class AuthEngineSuite { + static final String clientPrivate = + "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; + static final String clientChallengeHex = + "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + + "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + + "65f8c426e18ff380f6"; + static final String serverResponseHex = + "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + + "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + + "08ecad08b46b5ee3ff"; + static final String derivedKeyId = + "de04fd52d71040ed9d260579dacfdf4f5695f991ce8ddb1dde05a7335880906e"; + // This key would have been derived for version 1.0 protocol that did not run a final HKDF round. + static final String unsafeDerivedKey = + "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + static TransportConf conf; + + static TransportConf getConf(int authEngineVerison, boolean useCtr) { + String authEngineVersion = (authEngineVerison == 1) ? "1" : "2"; + String mode = useCtr ? 
"AES/CTR/NoPadding" : "AES/GCM/NoPadding"; + Map confMap = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.authEngineVersion", authEngineVersion, + "spark.network.crypto.cipher", mode + ); + ConfigProvider v2Provider = new MapConfigProvider(confMap); + return new TransportConf("rpc", v2Provider); } @Test public void testAuthEngine() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher serverCipher = server.sessionCipher(); TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), serverCipher.getKeyId()); + } + } - assertArrayEquals(serverCipher.getInputIv(), clientCipher.getOutputIv()); - assertArrayEquals(serverCipher.getOutputIv(), clientCipher.getInputIv()); - assertEquals(serverCipher.getKey(), clientCipher.getKey()); + @Test + public void testFixedChallengeResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + assertEquals(client.sessionCipher().getKeyId(), derivedKeyId); } } @Test public void testCorruptChallengeAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -91,7 +99,6 @@ public void testCorruptChallengeAppId() throws Exception { @Test public void testCorruptChallengeSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -102,7 +109,6 @@ public void testCorruptChallengeSalt() throws Exception { @Test public void testCorruptChallengeCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -113,7 +119,6 @@ public void testCorruptChallengeCiphertext() throws Exception { @Test public void testCorruptResponseAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -127,20 +132,18 @@ public void testCorruptResponseAppId() throws Exception { @Test public void testCorruptResponseSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); serverResponse.salt[0] ^= 1; assertThrows(GeneralSecurityException.class, - () -> client.deriveSessionCipher(clientChallenge, serverResponse)); + () -> client.deriveSessionCipher(clientChallenge, 
serverResponse)); } } @Test public void testCorruptServerCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -162,24 +165,6 @@ public void testFixedChallenge() throws Exception { } } - @Test - public void testFixedChallengeResponse() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), sharedKey); - assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); - } - } - @Test public void testMismatchedSecret() throws Exception { try (AuthEngine client = new AuthEngine("appId", "secret", conf); @@ -188,70 +173,4 @@ public void testMismatchedSecret() throws Exception { assertThrows(GeneralSecurityException.class, () -> server.response(clientChallenge)); } } - - @Test - public void testEncryptedMessage() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1]; - new Random().nextBytes(data); - ByteBuf buf = Unpooled.wrappedBuffer(data); - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); - while (emsg.transferred() < emsg.count()) { - emsg.transferTo(channel, emsg.transferred()); - } - assertEquals(data.length, channel.length()); - } - } - - @Test - public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - int testDataLength = 4; - FileRegion region = mock(FileRegion.class); - when(region.count()).thenReturn((long) testDataLength); - // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. 
- when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { - - private boolean firstTime = true; - - @Override - public Long answer(InvocationOnMock invocationOnMock) throws Throwable { - if (firstTime) { - firstTime = false; - return 0L; - } else { - WritableByteChannel channel = invocationOnMock.getArgument(0); - channel.write(ByteBuffer.wrap(new byte[testDataLength])); - return (long) testDataLength; - } - } - }); - - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); - // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. - assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); - assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); - assertEquals(emsg.transferred(), emsg.count()); - assertEquals(4, channel.length()); - } - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 4a5b426b1158a..ad8bbdb4c2655 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -49,7 +49,7 @@ public class AuthIntegrationSuite { private AuthTestCtx ctx; @After - public void cleanUp() throws Exception { + public void cleanUp() { if (ctx != null) { ctx.close(); } @@ -57,8 +57,8 @@ public void cleanUp() throws Exception { } @Test - public void testNewAuth() throws Exception { - ctx = new AuthTestCtx(); + public void testNewCtrAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); ctx.createServer("secret"); ctx.createClient("secret"); @@ -68,8 +68,28 @@ public void testNewAuth() throws Exception { } @Test - public void testAuthFailure() throws Exception { - ctx = new AuthTestCtx(); + public void testNewGcmAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); + ctx.createServer("secret"); + ctx.createClient("secret"); + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertNull(ctx.authRpcHandler.saslHandler); + } + + @Test + public void testCtrAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); + ctx.createServer("server"); + + assertThrows(Exception.class, () -> ctx.createClient("client")); + assertFalse(ctx.authRpcHandler.isAuthenticated()); + assertFalse(ctx.serverChannel.isActive()); + } + + @Test + public void testGcmAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); ctx.createServer("server"); assertThrows(Exception.class, () -> ctx.createClient("client")); @@ -100,7 +120,7 @@ public void testSaslClientFallback() throws Exception { } @Test - public void testAuthReplay() throws Exception { + public void testCtrAuthReplay() throws Exception { // This test covers the case where an attacker replays a challenge message sniffed from the // network, but doesn't know the actual secret. The server should close the connection as // soon as a message is sent after authentication is performed. 
This is emulated by removing @@ -110,16 +130,16 @@ public void testAuthReplay() throws Exception { ctx.createClient("secret"); assertNotNull(ctx.client.getChannel().pipeline() - .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); + .remove(CtrTransportCipher.ENCRYPTION_HANDLER_NAME)); assertThrows(Exception.class, () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000)); assertTrue(ctx.authRpcHandler.isAuthenticated()); } @Test - public void testLargeMessageEncryption() throws Exception { + public void testLargeCtrMessageEncryption() throws Exception { // Use a big length to create a message that cannot be put into the encryption buffer completely - final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE; + final int testErrorMessageLength = CtrTransportCipher.STREAM_BUFFER_SIZE; ctx = new AuthTestCtx(new RpcHandler() { @Override public void receive( @@ -157,6 +177,23 @@ public void testValidMergedBlockMetaReqHandler() throws Exception { assertNotNull(ctx.authRpcHandler.getMergedBlockMetaReqHandler()); } + private static class DummyRpcHandler extends RpcHandler { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String messageString = JavaUtils.bytesToString(message); + assertEquals("Ping", messageString); + callback.onSuccess(JavaUtils.stringToBytes("Pong")); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + } + private static class AuthTestCtx { private final String appId = "testAppId"; @@ -169,25 +206,17 @@ private static class AuthTestCtx { volatile AuthRpcHandler authRpcHandler; AuthTestCtx() throws Exception { - this(new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - assertEquals("Ping", JavaUtils.bytesToString(message)); - callback.onSuccess(JavaUtils.stringToBytes("Pong")); - } - - @Override - public StreamManager getStreamManager() { - return null; - } - }); + this(new DummyRpcHandler()); } AuthTestCtx(RpcHandler rpcHandler) throws Exception { - Map testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this(rpcHandler, "AES/CTR/NoPadding"); + } + + AuthTestCtx(RpcHandler rpcHandler, String mode) throws Exception { + Map testConf = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.cipher", mode); this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); this.ctx = new TransportContext(conf, rpcHandler); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java new file mode 100644 index 0000000000000..dcec2f17be532 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.crypto.tink.subtle.Hex; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; + +public class CtrAuthEngineSuite extends AuthEngineSuite { + private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; + private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; + + @Before + public void setUp() { + conf = getConf(2, true); + } + + @Test + public void testAuthEngine() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + assert(serverCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrClient = (CtrTransportCipher) clientCipher; + CtrTransportCipher ctrServer = (CtrTransportCipher) serverCipher; + assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv()); + assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv()); + assertEquals(ctrServer.getKey(), ctrClient.getKey()); + } + } + + @Test + public void testCtrFixedChallengeIvResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. 
+ client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), derivedKeyId); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testFixedChallengeResponseUnsafeVersion() throws Exception { + TransportConf v1Conf = getConf(1, true); + try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), unsafeDerivedKey); + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testCtrEncryptedMessage() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + + byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1]; + new Random().nextBytes(data); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); + while (emsg.transferred() < emsg.count()) { + emsg.transferTo(channel, emsg.transferred()); + } + assertEquals(data.length, channel.length()); + } + } + + @Test + public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + int testDataLength = 4; + FileRegion region = mock(FileRegion.class); + when(region.count()).thenReturn((long) testDataLength); + // Make `region.transferTo` do nothing 
in first call and transfer 4 bytes in the second one. + when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { + + private boolean firstTime = true; + + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + if (firstTime) { + firstTime = false; + return 0L; + } else { + WritableByteChannel channel = invocationOnMock.getArgument(0); + channel.write(ByteBuffer.wrap(new byte[testDataLength])); + return (long) testDataLength; + } + } + }); + + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); + // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. + assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); + assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); + assertEquals(emsg.transferred(), emsg.count()); + assertEquals(4, channel.length()); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java new file mode 100644 index 0000000000000..f25277aa1a997 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteBufferWriteableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.AEADBadTagException; +import java.io.IOException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.*; + +public class GcmAuthEngineSuite extends AuthEngineSuite { + + @Before + public void setUp() { + // Uses GCM mode + conf = getConf(2, false); + } + + @Test + public void testGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)]; + // Just writing some bytes. 
+ data[0] = 'a'; + data[data.length / 2] = 'b'; + data[data.length - 10] = 'c'; + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ((Buffer) ciphertextBuffer).flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)) + .fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('c', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + static class FakeRegion extends AbstractFileRegion { + private final ByteBuffer[] source; + private int sourcePosition; + private final long count; + + FakeRegion(ByteBuffer... source) { + this.source = source; + sourcePosition = 0; + count = remaining(); + } + + private long remaining() { + long remaining = 0; + for (ByteBuffer buffer : source) { + remaining += buffer.remaining(); + } + return remaining; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return count - remaining(); + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (sourcePosition < source.length) { + ByteBuffer currentBuffer = source[sourcePosition]; + long written = target.write(currentBuffer); + if (!currentBuffer.hasRemaining()) { + sourcePosition++; + } + return written; + } else { + return 0; + } + } + + @Override + protected void deallocate() { + } + } + + private static ByteBuffer getTestByteBuf(int size, byte fill) { + byte[] data = new byte[size]; + Arrays.fill(data, fill); + return ByteBuffer.wrap(data); + } + + @Test + public void testGcmEncryptedMessageFileRegion() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + 
gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int halfSegmentSize = plaintextSegmentSize / 2; + int totalSize = plaintextSegmentSize + halfSegmentSize; + + // Set up some fragmented segments to test + ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a'); + int smallFragmentSize = 128; + ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b'); + int remainderSize = totalSize - halfSegmentSize - smallFragmentSize; + ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c'); + FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder); + assertEquals(totalSize, fakeRegion.count()); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, fakeRegion, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + + // We'll simulate the FileRegion only transferring half a segment. + // The encrypted message should buffer the partial segment plaintext. + long ciphertextTransferred = 0; + while (ciphertextTransferred < encrypted.count()) { + long chunkTransferred = encrypted.transferTo(channel, 0); + ciphertextTransferred += chunkTransferred; + } + assertEquals(encrypted.count(), ciphertextTransferred); + + ((Buffer) ciphertextBuffer).flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + // We expect this to be the last partial plaintext segment + int expectedLength = totalSize % plaintextSegmentSize; + assertEquals(expectedLength, plaintext.readableBytes()); + // This will be the "remainder" segment that is filled to 'c' + assertEquals('c', plaintext.getByte(0)); + } + } + + + @Test + public void testGcmUnalignedDecryption() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional 
segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'x'); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ((Buffer) ciphertextBuffer).flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Split up the ciphertext into some different sized chunks + int firstChunkSize = plaintextSize / 2; + ByteBuf mockCiphertext = spy(ciphertext); + when(mockCiphertext.readableBytes()) + .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod(); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, mockCiphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('x', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + @Test + public void testCorruptGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert (clientCipher instanceof GcmTransportCipher); + + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + byte[] zeroData = new byte[1024 * 32]; + // Just writing some bytes. 
+ ByteBuf buf = Unpooled.wrappedBuffer(zeroData); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ((Buffer) ciphertextBuffer).flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + byte b = ciphertext.getByte(100); + // Inverting the bits of the 100th bit + ciphertext.setByte(100, ~b & 0xFF); + assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext)); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java index cde5c1c1022c4..35f7886e174a9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java @@ -41,10 +41,10 @@ public class TransportCipherSuite { @Test - public void testBufferNotLeaksOnInternalError() throws IOException { + public void testCtrBufferNotLeaksOnInternalError() throws IOException { String algorithm = "TestAlgorithm"; TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY); - TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(), + CtrTransportCipher cipher = new CtrTransportCipher(conf.cryptoConf(), new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) { @Override diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 48e64d21a58b0..b791a06aad43a 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 2bbacbe71a439..685ada5194905 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index fca31591b1ef1..b2e488c7bb222 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index a93e227655ea7..3a260a8dff53f 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 49f7e2d8c861e..fd0aa7ba2a3a2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 
a91ea2ee6b5a8..e02346c477375 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -96,7 +96,7 @@ public final class Platform { Method createMethod = cleanerClass.getMethod("create", Object.class, Runnable.class); // Accessing jdk.internal.ref.Cleaner should actually fail by default in JDK 9+, // unfortunately, unless the user has allowed access with something like - // --add-opens java.base/java.lang=ALL-UNNAMED If not, we can't really use the Cleaner + // --add-opens java.base/jdk.internal.ref=ALL-UNNAMED If not, we can't use the Cleaner // hack below. It doesn't break, just means the user might run into the default JVM limit // on off-heap memory and increase it or set the flag above. This tests whether it's // available: @@ -118,6 +118,11 @@ public final class Platform { } } + // Visible for testing + public static boolean cleanerCreateMethodIsDefined() { + return CLEANER_CREATE_METHOD != null; + } + /** * @return true when running JVM is having sun's Unsafe package available in it and underlying * system having unaligned-access capability. diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index c59878fea9941..c99f2d85f4e54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 +157,11 @@ public void heapMemoryReuse() { Assert.assertEquals(1024 * 1024 + 7, onheap4.size()); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + public void cleanerCreateMethodIsDefined() { + // Regression test for SPARK-45508: we don't expect the "no cleaner" fallback + // path to be hit in normal usage. + Assert.assertTrue(Platform.cleanerCreateMethodIsDefined()); + } } diff --git a/common/utils/pom.xml b/common/utils/pom.xml index c200c06a42e69..7c87be73d7d96 100644 --- a/common/utils/pom.xml +++ b/common/utils/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java index bbe764b8366c8..d6603dcbee1ae 100644 --- a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -120,6 +120,7 @@ public static void deleteRecursively(File file, FilenameFilter filter) throws IO private static void deleteRecursivelyUsingJavaIO( File file, FilenameFilter filter) throws IOException { + if (!file.exists()) return; BasicFileAttributes fileAttributes = Files.readAttributes(file.toPath(), BasicFileAttributes.class); if (fileAttributes.isDirectory() && !isSymlink(file)) { diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 477fe9b3f614e..f1943a8ff3e04 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -28,6 +28,15 @@ ], "sqlState" : "42702" }, + "AMBIGUOUS_COLUMN_REFERENCE" : { + "message" : [ + "Column is ambiguous. 
It's because you joined several DataFrame together, and some of these DataFrames are the same.", + "This column points to one of the DataFrame but Spark is unable to figure out which one.", + "Please alias the DataFrames with different names via `DataFrame.alias` before joining them,", + "and specify the column using qualified name, e.g. `df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`." + ], + "sqlState" : "42702" + }, "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { "message" : [ "Lateral column alias is ambiguous and has matches." @@ -69,6 +78,11 @@ } } }, + "AVRO_INCOMPATIBLE_READ_TYPE" : { + "message" : [ + "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer. To allow reading this field, enable the SQL configuration: \"spark.sql.legacy.avro.allowIncompatibleSchema\"." + ] + }, "BATCH_METADATA_NOT_FOUND" : { "message" : [ "Unable to find batch ." @@ -1278,6 +1292,11 @@ "which requires type, but the statement provided a value of incompatible type." ] }, + "NOT_CONSTANT" : { + "message" : [ + "which is not a constant expression whose equivalent value is known at query planning time." + ] + }, "SUBQUERY_EXPRESSION" : { "message" : [ "which contains subquery expressions." @@ -3042,7 +3061,7 @@ "subClass" : { "MULTI_GENERATOR" : { "message" : [ - "only one generator allowed per clause but found : ." + "only one generator allowed per SELECT clause but found : ." ] }, "NESTED_IN_EXPRESSIONS" : { diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 0f329b5655b32..2331a8e67b28e 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -61,7 +61,7 @@ private[spark] object SparkThrowableHelper { } def isInternalError(errorClass: String): Boolean = { - errorClass.startsWith("INTERNAL_ERROR") + errorClass != null && errorClass.startsWith("INTERNAL_ERROR") } def getMessage(e: SparkThrowable with Throwable, format: ErrorMessageFormat.Value): String = { diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 83e01330ce3f6..bd82ce962b8d0 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -196,7 +196,7 @@ private[spark] object Logging { val initLock = new Object() try { // We use reflection here to handle the case where users remove the - // slf4j-to-jul bridge order to route their logs to JUL. + // jul-to-slf4j bridge order to route their logs to JUL. 
val bridgeClass = SparkClassUtils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml index 63b411137ed7a..8bc2802ea5d0d 100644 --- a/connector/avro/pom.xml +++ b/connector/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 59f2999bdd395..2c2a45fc3f14f 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -61,7 +61,8 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val reader = new GenericDatumReader[Any](actualSchema, expectedSchema) @transient private lazy val deserializer = - new AvroDeserializer(expectedSchema, dataType, avroOptions.datetimeRebaseModeInRead) + new AvroDeserializer(expectedSchema, dataType, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index a78ee89a3e933..ec34d10a5ffe8 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -35,8 +35,9 @@ import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArr import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -48,18 +49,21 @@ private[sql] class AvroDeserializer( rootCatalystType: DataType, positionalFieldMatch: Boolean, datetimeRebaseSpec: RebaseSpec, - filters: StructFilters) { + filters: StructFilters, + useStableIdForUnionType: Boolean) { def this( rootAvroType: Schema, rootCatalystType: DataType, - datetimeRebaseMode: String) = { + datetimeRebaseMode: String, + useStableIdForUnionType: Boolean) = { this( rootAvroType, rootCatalystType, positionalFieldMatch = false, RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), - new NoopFilters) + new NoopFilters, + useStableIdForUnionType) } private lazy val decimalConversions = new DecimalConversion() @@ -101,6 +105,9 @@ private[sql] class AvroDeserializer( s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise) } + private lazy val preventReadingIncorrectType = !SQLConf.get + .getConf(SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA) + def deserialize(data: Any): Option[Any] = converter(data) /** @@ -117,6 +124,8 @@ private[sql] class AvroDeserializer( val incompatibleMsg = errorPrefix + s"schema is incompatible (avroType = $avroType, 
sqlType = ${catalystType.sql})" + val realDataType = SchemaConverters.toSqlType(avroType, useStableIdForUnionType).dataType + (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) @@ -128,9 +137,19 @@ private[sql] class AvroDeserializer( case (INT, IntegerType) => (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) + case (INT, dt: DatetimeType) + if preventReadingIncorrectType && realDataType.isInstanceOf[YearMonthIntervalType] => + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + case (INT, DateType) => (updater, ordinal, value) => updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) + case (LONG, dt: DatetimeType) + if preventReadingIncorrectType && realDataType.isInstanceOf[DayTimeIntervalType] => + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) @@ -204,17 +223,30 @@ private[sql] class AvroDeserializer( } updater.set(ordinal, bytes) - case (FIXED, _: DecimalType) => (updater, ordinal, value) => + case (FIXED, dt: DecimalType) => val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] - val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) - val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) - updater.setDecimal(ordinal, decimal) + if (preventReadingIncorrectType && + d.getPrecision - d.getScale > dt.precision - dt.scale) { + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + } + (updater, ordinal, value) => + val bigDecimal = + decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) - case (BYTES, _: DecimalType) => (updater, ordinal, value) => + case (BYTES, dt: DecimalType) => val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] - val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) - val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) - updater.setDecimal(ordinal, decimal) + if (preventReadingIncorrectType && + d.getPrecision - d.getScale > dt.precision - dt.scale) { + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + } + (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) case (RECORD, st: StructType) => // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328. 
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 53562a3afdb5b..7b0292df43c2f 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -141,7 +141,8 @@ private[sql] class AvroFileFormat extends FileFormat requiredSchema, parsedOptions.positionalFieldMatching, datetimeRebaseMode, - avroFilters) + avroFilters, + parsedOptions.useStableIdForUnionType) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 6f21639e28d68..af358a8d1c961 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -46,16 +46,24 @@ object SchemaConverters { */ case class SchemaType(dataType: DataType, nullable: Boolean) + /** + * Converts an Avro schema to a corresponding Spark SQL schema. + * + * @since 4.0.0 + */ + def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = { + toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType) + } /** * Converts an Avro schema to a corresponding Spark SQL schema. * * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(Map())) + toSqlType(avroSchema, false) } def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options)) + toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options).useStableIdForUnionType) } // The property specifies Catalyst type of the given field @@ -64,7 +72,7 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, existingRecordNames: Set[String], - avroOptions: AvroOptions): SchemaType = { + useStableIdForUnionType: Boolean): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -117,7 +125,7 @@ object SchemaConverters { } val newRecordNames = existingRecordNames + avroSchema.getFullName val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType) StructField(f.name, schemaType.dataType, schemaType.nullable) } @@ -127,13 +135,14 @@ object SchemaConverters { val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, - avroOptions) + useStableIdForUnionType) SchemaType( ArrayType(schemaType.dataType, containsNull = schemaType.nullable), nullable = false) case MAP => - val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(avroSchema.getValueType, + existingRecordNames, useStableIdForUnionType) SchemaType( MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), nullable = false) @@ -143,17 +152,18 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) if (remainingUnionTypes.size == 1) { - 
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, avroOptions) + toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, useStableIdForUnionType) .copy(nullable = true) } else { toSqlTypeHelper( Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames, - avroOptions).copy(nullable = true) + useStableIdForUnionType).copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => - toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames, avroOptions) + toSqlTypeHelper(avroSchema.getTypes.get(0), + existingRecordNames, useStableIdForUnionType) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -167,20 +177,20 @@ object SchemaConverters { val fieldNameSet : mutable.Set[String] = mutable.Set() val fields = avroSchema.getTypes.asScala.zipWithIndex.map { case (s, i) => - val schemaType = toSqlTypeHelper(s, existingRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType) - val fieldName = if (avroOptions.useStableIdForUnionType) { + val fieldName = if (useStableIdForUnionType) { // Avro's field name may be case sensitive, so field names for two named type // could be "a" and "A" and we need to distinguish them. In this case, we throw // an exception. - val temp_name = s"member_${s.getName.toLowerCase(Locale.ROOT)}" - if (fieldNameSet.contains(temp_name)) { + // Stable id prefix can be empty so the name of the field can be just the type. + val tempFieldName = s"member_${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { throw new IncompatibleSchemaException( - "Cannot generate stable indentifier for Avro union type due to name " + + "Cannot generate stable identifier for Avro union type due to name " + s"conflict of type name ${s.getName}") } - fieldNameSet.add(temp_name) - temp_name + tempFieldName } else { s"member$i" } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index cc7bd180e8477..2c85c1b067392 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -103,7 +103,8 @@ case class AvroPartitionReaderFactory( readDataSchema, options.positionalFieldMatching, datetimeRebaseMode, - avroFilters) + avroFilters, + options.useStableIdForUnionType) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 1cb34a0bc4dc5..250b5e0615ad8 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -59,7 +59,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite val expected = { val avroSchema = new Schema.Parser().parse(schema) - SchemaConverters.toSqlType(avroSchema).dataType match { + SchemaConverters.toSqlType(avroSchema, false).dataType match { case st: StructType => Row.fromSeq((0 until st.length).map(_ => null)) case _ => null } 
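// A minimal, runnable sketch of the stable union member naming used in toSqlTypeHelper above.
// The object, method, and exception below are hypothetical stand-ins (the real code throws
// IncompatibleSchemaException); only the naming and conflict rules mirror the patch: with
// useStableIdForUnionType the member name is "member_" plus the Avro type name with its case
// preserved, clashes are detected case-insensitively, and without the flag the positional
// names member0, member1, ... are used.
import java.util.Locale
import scala.collection.mutable

object UnionMemberNameSketch {
  def unionFieldNames(typeNames: Seq[String], useStableIdForUnionType: Boolean): Seq[String] = {
    val fieldNameSet = mutable.Set[String]()
    typeNames.zipWithIndex.map { case (typeName, i) =>
      if (useStableIdForUnionType) {
        val fieldName = s"member_$typeName"
        if (!fieldNameSet.add(fieldName.toLowerCase(Locale.ROOT))) {
          throw new IllegalArgumentException(
            s"Cannot generate stable identifier for Avro union type due to name conflict " +
              s"of type name $typeName")
        }
        fieldName
      } else {
        s"member$i"
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Case is preserved in the generated member names.
    assert(unionFieldNames(Seq("myRecord", "myRecord2"), useStableIdForUnionType = true) ==
      Seq("member_myRecord", "member_myRecord2"))
    // "myRecord" and "myrecord" differ only in case, so stable naming rejects the union.
    assert(scala.util.Try(
      unionFieldNames(Seq("myRecord", "myrecord"), useStableIdForUnionType = true)).isFailure)
  }
}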
@@ -281,13 +281,14 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite data: GenericData.Record, expected: Option[Any], filters: StructFilters = new NoopFilters): Unit = { - val dataType = SchemaConverters.toSqlType(schema).dataType + val dataType = SchemaConverters.toSqlType(schema, false).dataType val deserializer = new AvroDeserializer( schema, dataType, false, RebaseSpec(LegacyBehaviorPolicy.CORRECTED), - filters) + filters, + false) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 70d0bc6c0ad10..965e3a0c1cba6 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -75,7 +75,8 @@ class AvroRowReaderSuite StructType(new StructField("value", IntegerType, true) :: Nil), false, RebaseSpec(CORRECTED), - new NoopFilters) + new NoopFilters, + false) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index 7f99f3c737c86..a21f3f008fdc7 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -226,7 +226,8 @@ object AvroSerdeSuite { sql, isPositional(matchType), RebaseSpec(CORRECTED), - new NoopFilters) + new NoopFilters, + false) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index d22a2d3697579..01c9dfb57a191 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -32,6 +32,7 @@ import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} import org.apache.spark.TestUtils.assertExceptionMsg @@ -369,7 +370,7 @@ abstract class AvroSuite "", Seq()) } - assert(e.getMessage.contains("Cannot generate stable indentifier")) + assert(e.getMessage.contains("Cannot generate stable identifier")) } { val e = intercept[Exception] { @@ -380,7 +381,7 @@ abstract class AvroSuite "", Seq()) } - assert(e.getMessage.contains("Cannot generate stable indentifier")) + assert(e.getMessage.contains("Cannot generate stable identifier")) } // Two array types or two map types are not allowed in union. 
{ @@ -433,6 +434,33 @@ abstract class AvroSuite } } + test("SPARK-47904: Test that field name case is preserved") { + checkUnionStableId( + List( + Schema.createEnum("myENUM", "", null, List[String]("E1", "e2").asJava), + Schema.createRecord("myRecord", "", null, false, + List[Schema.Field](new Schema.Field("f", Schema.createFixed("myField", "", null, 6))) + .asJava), + Schema.createRecord("myRecord2", "", null, false, + List[Schema.Field](new Schema.Field("F", Schema.create(Type.FLOAT))) + .asJava)), + "struct<member_myENUM: string, member_myRecord: struct<f: binary>, " + + "member_myRecord2: struct<F: float>>", + Seq()) + + { + val e = intercept[Exception] { + checkUnionStableId( + List( + Schema.createRecord("myRecord", "", null, false, List[Schema.Field]().asJava), + Schema.createRecord("myrecord", "", null, false, List[Schema.Field]().asJava)), + "", + Seq()) + } + assert(e.getMessage.contains("Cannot generate stable identifier")) + } + } + test("SPARK-27858 Union type: More than one non-null type") { Seq(true, false).foreach { isStableUnionMember => withTempDir { dir => @@ -814,6 +842,163 @@ abstract class AvroSuite } } + test("SPARK-43380: Fix Avro data type conversion" + + " of decimal type to avoid producing incorrect results") { + withTempPath { path => + val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key + sql("SELECT 13.1234567890 a").write.format("avro").save(path.toString) + // With the flag disabled, we will throw an exception if there is a mismatch + withSQLConf(confKey -> "false") { + val e = intercept[SparkException] { + spark.read.schema("a DECIMAL(4, 3)").format("avro").load(path.toString).collect() + } + ExceptionUtils.getRootCause(e) match { + case ex: AnalysisException => + checkError( + exception = ex, + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + parameters = Map("avroPath" -> "field 'a'", + "sqlPath" -> "field 'a'", + "avroType" -> "decimal\\(12,10\\)", + "sqlType" -> "\"DECIMAL\\(4,3\\)\""), + matchPVals = true + ) + case other => + fail(s"Received unexpected exception", other) + } + } + // The following used to work, so it should still work with the flag enabled + checkAnswer( + spark.read.schema("a DECIMAL(5, 3)").format("avro").load(path.toString), + Row(new java.math.BigDecimal("13.123")) + ) + withSQLConf(confKey -> "true") { + // With the flag enabled, we return a null silently, which isn't great + checkAnswer( + spark.read.schema("a DECIMAL(4, 3)").format("avro").load(path.toString), + Row(null) + ) + checkAnswer( + spark.read.schema("a DECIMAL(5, 3)").format("avro").load(path.toString), + Row(new java.math.BigDecimal("13.123")) + ) + } + } + } + + test("SPARK-43380: Fix Avro data type conversion" + + " of DayTimeIntervalType to avoid producing incorrect results") { + withTempPath { path => + val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key + val schema = StructType(Array(StructField("a", DayTimeIntervalType(), false))) + val data = Seq(Row(java.time.Duration.ofDays(1).plusSeconds(1))) + + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.format("avro").save(path.getCanonicalPath) + + withSQLConf(confKey -> "false") { + Seq("DATE", "TIMESTAMP", "TIMESTAMP_NTZ").foreach { sqlType => + val e = intercept[SparkException] { + spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() + } + + ExceptionUtils.getRootCause(e) match { + case ex: AnalysisException => + checkError( + exception = ex, + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + parameters = Map("avroPath" -> "field 'a'", + "sqlPath" -> "field 'a'", + "avroType" -> "interval day to 
second", + "sqlType" -> s""""$sqlType""""), + matchPVals = true + ) + case other => + fail(s"Received unexpected exception", other) + } + } + } + + withSQLConf(confKey -> "true") { + // Allow conversion and do not need to check result + spark.read.schema("a Date").format("avro").load(path.toString) + spark.read.schema("a timestamp").format("avro").load(path.toString) + spark.read.schema("a timestamp_ntz").format("avro").load(path.toString) + } + } + } + + test("SPARK-43380: Fix Avro data type conversion" + + " of YearMonthIntervalType to avoid producing incorrect results") { + withTempPath { path => + val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key + val schema = StructType(Array(StructField("a", YearMonthIntervalType(), false))) + val data = Seq(Row(java.time.Period.of(1, 1, 0))) + + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.format("avro").save(path.getCanonicalPath) + + withSQLConf(confKey -> "false") { + Seq("DATE", "TIMESTAMP", "TIMESTAMP_NTZ").foreach { sqlType => + val e = intercept[SparkException] { + spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() + } + + ExceptionUtils.getRootCause(e) match { + case ex: AnalysisException => + checkError( + exception = ex, + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + parameters = Map("avroPath" -> "field 'a'", + "sqlPath" -> "field 'a'", + "avroType" -> "interval year to month", + "sqlType" -> s""""$sqlType""""), + matchPVals = true + ) + case other => + fail(s"Received unexpected exception", other) + } + } + } + + withSQLConf(confKey -> "true") { + // Allow conversion and do not need to check result + spark.read.schema("a Date").format("avro").load(path.toString) + spark.read.schema("a timestamp").format("avro").load(path.toString) + spark.read.schema("a timestamp_ntz").format("avro").load(path.toString) + } + } + } + + Seq( + "time-millis", + "time-micros", + "timestamp-micros", + "timestamp-millis", + "local-timestamp-millis", + "local-timestamp-micros" + ).foreach { timeLogicalType => + test(s"converting $timeLogicalType type to long in avro") { + withTempPath { path => + val df = Seq(100L) + .toDF("dt") + val avroSchema = + s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "dt", "type": {"type": "long", "logicalType": "$timeLogicalType"}} + | ] + |}""".stripMargin + df.write.format("avro").option("avroSchema", avroSchema).save(path.getCanonicalPath) + checkAnswer( + spark.read.schema(s"dt long").format("avro").load(path.toString), + Row(100L)) + } + } + } + test("converting some specific sparkSQL types to avro") { withTempPath { tempDir => val testSchema = StructType(Seq( @@ -1979,7 +2164,7 @@ abstract class AvroSuite private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { val message = intercept[IncompatibleSchemaException] { - SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema)) + SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false) }.getMessage assert(message.contains("Found recursive reference in Avro schema")) diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 8c9d11f64eec8..87f6a589261cc 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../../../pom.xml @@ -50,10 +50,26 @@ spark-sketch_${scala.binary.version} ${project.version} + + + com.google.protobuf + protobuf-java + compile + com.google.guava 
guava ${connect.guava.version} + compile + + + com.google.guava + failureaccess + ${guava.failureaccess.version} + compile com.lihaoyi @@ -61,6 +77,16 @@ ${ammonite.version} provided + + commons-io + commons-io + test + + + org.apache.commons + commons-lang3 + test + org.scalacheck scalacheck_${scala.binary.version} @@ -85,59 +111,78 @@ maven-shade-plugin false + true + com.google.guava:* com.google.android:* com.google.api.grpc:* com.google.code.findbugs:* com.google.code.gson:* com.google.errorprone:* - com.google.guava:* com.google.j2objc:* com.google.protobuf:* + com.google.flatbuffers:* io.grpc:* io.netty:* io.perfmark:* + org.apache.arrow:* org.codehaus.mojo:* org.checkerframework:* org.apache.spark:spark-connect-common_${scala.binary.version} + org.apache.spark:spark-sql-api_${scala.binary.version} + + com.google.common + ${spark.shade.packageName}.connect.guava + + com.google.common.** + + io.grpc - ${spark.shade.packageName}.connect.client.io.grpc + ${spark.shade.packageName}.io.grpc io.grpc.** com.google - ${spark.shade.packageName}.connect.client.com.google + ${spark.shade.packageName}.com.google + + + com.google.common.** + io.netty - ${spark.shade.packageName}.connect.client.io.netty + ${spark.shade.packageName}.io.netty org.checkerframework - ${spark.shade.packageName}.connect.client.org.checkerframework + ${spark.shade.packageName}.org.checkerframework javax.annotation - ${spark.shade.packageName}.connect.client.javax.annotation + ${spark.shade.packageName}.javax.annotation io.perfmark - ${spark.shade.packageName}.connect.client.io.perfmark + ${spark.shade.packageName}.io.perfmark org.codehaus - ${spark.shade.packageName}.connect.client.org.codehaus + ${spark.shade.packageName}.org.codehaus + + + org.apache.arrow + ${spark.shade.packageName}.org.apache.arrow android.annotation - ${spark.shade.packageName}.connect.client.android.annotation + ${spark.shade.packageName}.android.annotation diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index bdaa4e28ba892..865596a669a09 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1291,7 +1291,7 @@ class Dataset[T] private[sql] ( val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toSeq.map(_.expr).asJava) - .setValueColumnName(variableColumnName) + .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 7882ea6401354..421f37b9e8a62 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -134,8 +134,6 @@ class SparkSession private[sql] ( } else { val hash = client.cacheLocalRelation(arrowData, encoder.schema.json) builder.getCachedLocalRelationBuilder - .setUserId(client.userId) - .setSessionId(client.sessionId) .setHash(hash) } } else { @@ -785,7 +783,10 @@ object SparkSession extends Logging { } class Builder() extends Logging { - private val builder = SparkConnectClient.builder() + // Initialize the connection string of the Spark Connect client builder from 
SPARK_REMOTE + // by default, if it exists. The connection string can be overridden using + // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. + private val builder = SparkConnectClient.builder().loadFromEnvironment() private var client: SparkConnectClient = _ private[this] val options = new scala.collection.mutable.HashMap[String, String] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index fe992ae6740bf..8f55954a63f33 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -2624,7 +2624,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ln(e: Column): Column = log(e) + def ln(e: Column): Column = Column.fn("ln", e) /** * Computes the natural logarithm of the given value. @@ -2632,7 +2632,7 @@ object functions { * @group math_funcs * @since 3.4.0 */ - def log(e: Column): Column = Column.fn("log", e) + def log(e: Column): Column = ln(e) /** * Computes the natural logarithm of the given column. @@ -3477,7 +3477,7 @@ object functions { mode: Column, padding: Column, aad: Column): Column = - Column.fn("aes_encrypt", input, key, mode, padding, aad) + Column.fn("aes_decrypt", input, key, mode, padding, aad) /** * Returns a decrypted value of `input`. @@ -3489,7 +3489,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column, mode: Column, padding: Column): Column = - Column.fn("aes_encrypt", input, key, mode, padding) + Column.fn("aes_decrypt", input, key, mode, padding) /** * Returns a decrypted value of `input`. @@ -3501,7 +3501,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column, mode: Column): Column = - Column.fn("aes_encrypt", input, key, mode) + Column.fn("aes_decrypt", input, key, mode) /** * Returns a decrypted value of `input`. @@ -3513,7 +3513,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column): Column = - Column.fn("aes_encrypt", input, key) + Column.fn("aes_decrypt", input, key) /** * This is a special version of `aes_decrypt` that performs the same operation, but returns a diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index df36b53791a81..feefd19000d1d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -767,6 +767,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assert(joined2.schema.catalogString === "struct") } + test("SPARK-45509: ambiguous column reference") { + val session = spark + import session.implicits._ + val df1 = Seq(1 -> "a").toDF("i", "j") + val df1_filter = df1.filter(df1("i") > 0) + val df2 = Seq(2 -> "b").toDF("i", "y") + + checkSameResult( + Seq(Row(1)), + // df1("i") is not ambiguous, and it's still valid in the filtered df. + df1_filter.select(df1("i"))) + + val e1 = intercept[AnalysisException] { + // df1("i") is not ambiguous, but it's not valid in the projected df. 
+ df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect() + } + assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT")) + + checkSameResult( + Seq(Row(1, "a")), + // All these column references are not ambiguous and are still valid after join. + df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j"))) + + val e2 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1, df1("i") === 1).collect() + } + assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e3 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1).select(df1("i")).collect() + } + assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e4 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides (df1_filter contains df1). + df1.join(df1_filter, df1("i") === 1).collect() + } + assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter does not exist in the join left side. + df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j"))) + + val e5 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both sides of the first join. + df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect() + } + assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter only appears once. + df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j"))) + } + test("broadcast join") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { val left = spark.range(100).select(col("id"), rand(10).as("a")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index c76dc724828e5..e9c2f0c457508 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -108,7 +108,8 @@ class SparkSessionE2ESuite extends RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } - test("interrupt tag") { + // TODO(SPARK-48139): Re-enable `SparkSessionE2ESuite.interrupt tag` + ignore("interrupt tag") { val session = spark import session.implicits._ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 5bb8cbf3543b0..9d61b4d56e1ed 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -362,4 +362,20 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { val output = runCommandsInShell(input) assertContains("noException: Boolean = true", output) } + + test("broadcast works with REPL generated code") { + val input = + """ + |val add1 = udf((i: Long) => i + 1) + |val tableA = spark.range(2).alias("a") + |val tableB = broadcast(spark.range(2).select(add1(col("id")).alias("id"))).alias("b") + |tableA.join(tableB). 
+ | where(col("a.id")===col("b.id")). + | select(col("a.id").alias("a_id"), col("b.id").alias("b_id")). + | collect(). + | mkString("[", ", ", "]") + |""".stripMargin + val output = runCommandsInShell(input) + assertContains("""String = "[[1,1]]"""", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 80e245ec78b7d..89acc2c60ac21 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -86,6 +86,24 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(response.getSessionId === "abc123") } + private def withEnvs(pairs: (String, String)*)(f: => Unit): Unit = { + val readonlyEnv = System.getenv() + val field = readonlyEnv.getClass.getDeclaredField("m") + field.setAccessible(true) + val modifiableEnv = field.get(readonlyEnv).asInstanceOf[java.util.Map[String, String]] + try { + for ((k, v) <- pairs) { + assert(!modifiableEnv.containsKey(k)) + modifiableEnv.put(k, v) + } + f + } finally { + for ((k, _) <- pairs) { + modifiableEnv.remove(k) + } + } + } + test("Test connection") { testClientConnection() { testPort => SparkConnectClient.builder().port(testPort).build() } } @@ -112,6 +130,49 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } + test("SparkSession create with SPARK_REMOTE") { + startDummyServer(0) + + withEnvs("SPARK_REMOTE" -> s"sc://localhost:${server.getPort}") { + val session = SparkSession.builder().create() + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + + val session2 = SparkSession.builder().create() + assert(session != session2) + } + } + + test("SparkSession getOrCreate with SPARK_REMOTE") { + startDummyServer(0) + + withEnvs("SPARK_REMOTE" -> s"sc://localhost:${server.getPort}") { + val session = SparkSession.builder().getOrCreate() + + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + + val session2 = SparkSession.builder().getOrCreate() + assert(session === session2) + } + } + + test("Builder.remote takes precedence over SPARK_REMOTE") { + startDummyServer(0) + val incorrectUrl = s"sc://localhost:${server.getPort + 1}" + + withEnvs("SPARK_REMOTE" -> incorrectUrl) { + val session = + SparkSession.builder().remote(s"sc://localhost:${server.getPort}").getOrCreate() + + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + } + } + test("SparkSession initialisation with connection string") { startDummyServer(0) client = SparkConnectClient diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index c78c5445e5073..994179fd99ac8 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../../pom.xml @@ -47,18 +47,6 @@ com.google.protobuf protobuf-java - - com.google.guava - guava - ${connect.guava.version} - compile - - - com.google.guava - failureaccess - ${guava.failureaccess.version} - compile - io.grpc grpc-netty @@ -152,6 +140,27 @@ + + org.apache.maven.plugins + maven-shade-plugin + + false + + 
+ org.spark-project.spark:unused + org.apache.tomcat:annotations-api + + + + + + package + + shade + + + + diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 8001b3cbcfaa4..f7f1315ede0f8 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -400,11 +400,11 @@ message LocalRelation { // A local relation that has been cached already. message CachedLocalRelation { - // (Required) An identifier of the user which created the local relation - string userId = 1; - - // (Required) An identifier of the Spark SQL session in which the user created the local relation. - string sessionId = 2; + // `userId` and `sessionId` fields are deleted since the server must always use the active + // session/user rather than arbitrary values provided by the client. It is never valid to access + // a local relation from a different session/user. + reserved 1, 2; + reserved "userId", "sessionId"; // (Required) A sha-256 hash of the serialized local relation in proto, see LocalRelation. string hash = 3; diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala index 891e50ed6e7bd..810158b2ac8b3 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala @@ -27,6 +27,20 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { } } +private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] { + + def innerIterator: Iterator[E] + + override def next(): E = innerIterator.next() + + override def hasNext(): Boolean = innerIterator.hasNext + + override def close(): Unit = innerIterator match { + case it: CloseableIterator[E] => it.close() + case _ => // nothing + } +} + private[sql] object CloseableIterator { /** @@ -34,13 +48,9 @@ private[sql] object CloseableIterator { */ def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match { case closeable: CloseableIterator[T] => closeable - case _ => - new CloseableIterator[T] { - override def next(): T = iterator.next() - - override def hasNext(): Boolean = iterator.hasNext - - override def close() = { /* empty */ } + case iter => + new WrappedCloseableIterator[T] { + override def innerIterator: Iterator[T] = iter } } } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index 73ff01e223f29..80edcfa8be16a 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -22,7 +22,7 @@ import io.grpc.ManagedChannel import org.apache.spark.connect.proto._ -private[client] class CustomSparkConnectBlockingStub( +private[connect] class CustomSparkConnectBlockingStub( channel: ManagedChannel, retryPolicy: GrpcRetryHandler.RetryPolicy) { diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 9bf7de33da8a7..57a629264be10 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client import java.util.UUID +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import io.grpc.{ManagedChannel, StatusRuntimeException} @@ -50,7 +51,7 @@ class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, channel: ManagedChannel, retryPolicy: GrpcRetryHandler.RetryPolicy) - extends CloseableIterator[proto.ExecutePlanResponse] + extends WrappedCloseableIterator[proto.ExecutePlanResponse] with Logging { val operationId = if (request.hasOperationId) { @@ -86,14 +87,25 @@ class ExecutePlanResponseReattachableIterator( // True after ResultComplete message was seen in the stream. // Server will always send this message at the end of the stream, if the underlying iterator // finishes without producing one, another iterator needs to be reattached. - private var resultComplete: Boolean = false + // Visible for testing. + private[connect] var resultComplete: Boolean = false // Initial iterator comes from ExecutePlan request. // Note: This is not retried, because no error would ever be thrown here, and GRPC will only // throw error on first iter.hasNext() or iter.next() - private var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = + // Visible for testing. + private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = Some(rawBlockingStub.executePlan(initialRequest)) + override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match { + case Some(it) => it.asScala + case None => + // The iterator is only unset for short moments while retry exception is thrown. + // It should only happen in the middle of internal processing. Since this iterator is not + // thread safe, no-one should be accessing it at this moment. 
+ throw new IllegalStateException("innerIterator unset") + } + override def next(): proto.ExecutePlanResponse = synchronized { // hasNext will trigger reattach in case the stream completed without resultComplete if (!hasNext()) { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index c430485bd4184..fe9f6dc2b4a9a 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -43,7 +43,10 @@ private[client] object GrpcExceptionConverter extends JsonUtils { } def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = { - new CloseableIterator[T] { + new WrappedCloseableIterator[T] { + + override def innerIterator: Iterator[T] = iter + override def hasNext: Boolean = { convert { iter.hasNext diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index 8791530607c3a..3c0b750fd46e7 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -48,11 +48,13 @@ private[sql] class GrpcRetryHandler( * The type of the response. */ class RetryIterator[T, U](request: T, call: T => CloseableIterator[U]) - extends CloseableIterator[U] { + extends WrappedCloseableIterator[U] { private var opened = false // we only retry if it fails on first call when using the iterator private var iter = call(request) + override def innerIterator: Iterator[U] = iter + private def retryIter[V](f: Iterator[U] => V) = { if (!opened) { opened = true diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index cd54966ccf54d..9429578598712 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -332,15 +332,17 @@ object ArrowDeserializers { val constructor = methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) val lookup = createFieldLookup(vectors) - val setters = fields.map { field => - val vector = lookup(field.name) - val deserializer = deserializerFor(field.enc, vector, timeZoneId) - val setter = methodLookup.findVirtual( - tag.runtimeClass, - field.writeMethod.get, - MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) - (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) - } + val setters = fields + .filter(_.writeMethod.isDefined) + .map { field => + val vector = lookup(field.name) + val deserializer = deserializerFor(field.enc, vector, timeZoneId) + val setter = methodLookup.findVirtual( + tag.runtimeClass, + field.writeMethod.get, + MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) + } new StructFieldSerializer[Any](struct) { def value(i: Int): Any = { val instance = 
constructor.invoke() diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain index 44084a8e60fb0..31e03b79eb987 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, GCM, DEFAULT, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, GCM, DEFAULT, )#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain index 29ccf0c1c833f..fc572e8fe7c67 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, DEFAULT, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, DEFAULT, )#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain index 5591363426ab5..c6c693013dd0a 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, 
StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, )#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain index 0e8d4df71b38e..97bb528b84b3f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, g, )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain index bc3c6e4bb2bcf..27058b201c01a 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_base64.explain @@ -1,2 +1,2 @@ -Project [base64(cast(g#0 as binary)) AS base64(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.Base64, StringType, encode, cast(g#0 as binary), true, BinaryType, BooleanType, true, false, true) AS base64(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ln.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ln.explain index d3c3743b1ef40..66b782ac8170d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ln.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ln.explain @@ -1,2 +1,2 @@ -Project [LOG(E(), b#0) AS LOG(E(), b)#0] +Project [ln(b#0) AS ln(b)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_log.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_log.explain index d3c3743b1ef40..66b782ac8170d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_log.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_log.explain @@ -1,2 +1,2 @@ -Project [LOG(E(), b#0) AS LOG(E(), b)#0] +Project [ln(b#0) AS ln(b)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain index 
f61fc30a3a529..053937d84ec8f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] +Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, name#0, value#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain index b5742d976dee9..5a953f792cd35 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain @@ -1,2 +1,2 @@ -Expand [[a#0, id, id#0L]], [a#0, #0, value#0L] +Expand [[a#0, id, id#0L]], [a#0, name#0, value#0L] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain index 8d1749ee74c5a..2b2ba19d0c3db 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a, cast(a#0 as double)], [id#0L, b, b#0]], [id#0L, #0, value#0] +Expand [[id#0L, a, cast(a#0 as double)], [id#0L, b, b#0]], [id#0L, name#0, value#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain index f61fc30a3a529..053937d84ec8f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] +Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, name#0, value#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.json b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.json index 06469d4840547..4204a44b44ce0 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "aes_encrypt", + "functionName": "aes_decrypt", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "g" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.proto.bin index c7a70b51707f3..f635e1fc689b1 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.json 
b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.json index 7eb9b4ed8b4ed..9c630e1253494 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "aes_encrypt", + "functionName": "aes_decrypt", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "g" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.proto.bin index ecd81ae44fcbd..41d024cdb7eed 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.json b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.json index 59a6a5e35fd42..8f5be474ab4b3 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "aes_encrypt", + "functionName": "aes_decrypt", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "g" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.proto.bin index 9de01ddc5ea69..cd6764581f2ca 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.json b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.json index a87ec1b7f4d29..9381042b71886 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "aes_encrypt", + "functionName": "aes_decrypt", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "g" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.proto.bin index 13da507fe6ff4..ca789f04ce1d4 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_aes_decrypt_with_mode_padding_aad.proto.bin 
differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ln.json b/connector/connect/common/src/test/resources/query-tests/queries/function_ln.json index 1b2d0ed0b1447..ababbc52d088d 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_ln.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_ln.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "log", + "functionName": "ln", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "b" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ln.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_ln.proto.bin index 548fb480dd27e..ecb87a1fc4102 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_ln.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_ln.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_log.json b/connector/connect/common/src/test/resources/query-tests/queries/function_log.json index 1b2d0ed0b1447..ababbc52d088d 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_log.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_log.json @@ -13,7 +13,7 @@ }, "expressions": [{ "unresolvedFunction": { - "functionName": "log", + "functionName": "ln", "arguments": [{ "unresolvedAttribute": { "unparsedIdentifier": "b" diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_log.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_log.proto.bin index 548fb480dd27e..ecb87a1fc4102 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_log.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_log.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.json b/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.json index 12db0a5abe368..a17da06b925b9 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.json @@ -20,6 +20,7 @@ "unparsedIdentifier": "a" } }], + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin index 23a6aa1289a99..eebb7ad6df8e2 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/melt_values.json b/connector/connect/common/src/test/resources/query-tests/queries/melt_values.json index e2a004f46e781..a8142ee3a8461 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/melt_values.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/melt_values.json @@ -23,6 +23,7 @@ } }] }, + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin index e021e1110def5..35829fc62dae9 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json index 9f550c0319147..96b76443b6790 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json @@ -16,6 +16,7 @@ "unparsedIdentifier": "id" } }], + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin index ac3bad8bd04ed..b700190a9f667 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.json b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.json index 92bc19d195c6e..6c31afb04e741 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.json @@ -27,6 +27,7 @@ } }] }, + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin index 7f717cb23517b..a1cd388fd8a46 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin differ diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml index 10deea435d2bd..801c28319ee84 100644 --- a/connector/connect/server/pom.xml +++ b/connector/connect/server/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../../pom.xml @@ -168,6 +168,7 @@ com.google.guava failureaccess ${guava.failureaccess.version} + compile com.google.protobuf diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 7b8b05ce11a82..253ac38f9cf9e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -133,7 +133,7 @@ object Connect { "With any value greater than 0, the last sent response will always be buffered.") .version("3.5.0") .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("1m") + .createWithDefaultString("10m") val CONNECT_EXTENSIONS_RELATION_CLASSES = 
buildStaticConf("spark.connect.extensions.relation.classes") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala index 6b8fcde1156ed..c3c33a85d6517 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -47,6 +47,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( private var interrupted = false + // Time at which this sender should finish if the response stream is not finished by then. + private var deadlineTimeMillis = Long.MaxValue + // Signal to wake up when grpcCallObserver.isReady() private val grpcCallObserverReadySignal = new Object @@ -65,6 +68,12 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( executionObserver.notifyAll() } + // For testing + private[connect] def setDeadline(deadlineMs: Long) = executionObserver.synchronized { + deadlineTimeMillis = deadlineMs + executionObserver.notifyAll() + } + def run(lastConsumedStreamIndex: Long): Unit = { if (executeHolder.reattachable) { // In reattachable execution we use setOnReadyHandler and grpcCallObserver.isReady to control @@ -150,7 +159,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( var finished = false // Time at which this sender should finish if the response stream is not finished by then. - val deadlineTimeMillis = if (!executeHolder.reattachable) { + deadlineTimeMillis = if (!executeHolder.reattachable) { Long.MaxValue } else { val confSize = @@ -232,8 +241,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( assert(finished == false) } else { // If it wasn't sent, time deadline must have been reached before stream became available, - // will exit in the enxt loop iterattion. - assert(deadlineLimitReached) + // or it was interrupted. Will exit in the next loop iteration. + assert(deadlineLimitReached || interrupted) } } else if (streamFinished) { // Stream is finished and all responses have been sent @@ -301,7 +310,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( val sleepStart = System.nanoTime() var sleepEnd = 0L // Conditions for exiting the inner loop - // 1. was detached + // 1. was interrupted // 2. grpcCallObserver is ready to send more data // 3. time deadline is reached while (!interrupted && diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala index d9db07fd228ed..df0fb3ac3a592 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -73,11 +73,16 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: /** The index of the last response produced by execution. */ private var lastProducedIndex: Long = 0 // first response will have index 1 + // For testing + private[connect] var releasedUntilIndex: Long = 0 + /** * Highest response index that was consumed. Keeps track of it to decide which responses needs * to be cached, and to assert that all responses are consumed.
+ * + * Visible for testing. */ - private var highestConsumedIndex: Long = 0 + private[connect] var highestConsumedIndex: Long = 0 /** * Consumer that waits for available responses. There can be only one at a time, @see @@ -284,6 +289,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: responses.remove(i) i -= 1 } + releasedUntilIndex = index } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 62083d4892f78..d503dde3d18c1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.execution +import scala.concurrent.{ExecutionContext, Promise} +import scala.util.Try import scala.util.control.NonFatal import com.google.protobuf.Message @@ -29,7 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag} import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * This class launches the actual execution in an execution thread. The execution pushes the @@ -37,10 +39,12 @@ import org.apache.spark.util.Utils */ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { + private val promise: Promise[Unit] = Promise[Unit]() + // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. If considering implementing a thread-pool, // forwarding of thread locals needs to be taken into account. - private var executionThread: Thread = new ExecutionThread() + private val executionThread: Thread = new ExecutionThread(promise) private var interrupted: Boolean = false @@ -53,9 +57,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends executionThread.start() } - /** Joins the background execution thread after it is finished. 
*/ - def join(): Unit = { - executionThread.join() + /** + * Register a callback that gets executed after completion/interruption of the execution + */ + private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { + promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) } /** @@ -222,10 +228,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .build() } - private class ExecutionThread + private class ExecutionThread(onCompletionPromise: Promise[Unit]) extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { override def run(): Unit = { - execute() + try { + execute() + onCompletionPromise.success(()) + } catch { + case NonFatal(e) => + onCompletionPromise.failure(e) + } } } } + +private[connect] object ExecuteThreadRunner { + private implicit val namedExecutionContext: ExecutionContext = ExecutionContext + .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback")) +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 2abbacc5a9b7f..709e0811e5de2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => transformCachedRemoteRelation(rel.getCachedRemoteRelation) case proto.Relation.RelTypeCase.COLLECT_METRICS => - transformCollectMetrics(rel.getCollectMetrics) + transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") @@ -674,8 +674,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { transformTypedCoGroupMap(rel, commonUdf) case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => - val pythonUdf = transformPythonUDF(commonUdf) - val inputCols = rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) @@ -690,6 +688,10 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .ofRows(session, transformRelation(rel.getOther)) .groupBy(otherCols: _*) + val pythonUdf = createUserDefinedPythonFunction(commonUdf) + .builder(input.df.logicalPlan.output ++ other.df.logicalPlan.output) + .asInstanceOf[PythonUDF] + input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan case _ => @@ -970,7 +972,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = { val blockManager = session.sparkContext.env.blockManager - val blockId = CacheId(rel.getUserId, rel.getSessionId, rel.getHash) + val blockId = CacheId(sessionHolder.userId, sessionHolder.sessionId, rel.getHash) val bytes = blockManager.getLocalBytes(blockId) bytes .map { blockData => @@ -1054,12 +1056,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { numPartitionsOpt) } - private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = { + private 
def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = { val metrics = rel.getMetricsList.asScala.toSeq.map { expr => Column(transformExpression(expr)) } - CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput)) + CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId) } private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { @@ -1587,17 +1589,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private def transformPythonFuncExpression( fun: proto.CommonInlineUserDefinedFunction): Expression = { + createUserDefinedPythonFunction(fun) + .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match { + case udaf: PythonUDAF => udaf.toAggregateExpression() + case other => other + } + } + + private def createUserDefinedPythonFunction( + fun: proto.CommonInlineUserDefinedFunction): UserDefinedPythonFunction = { val udf = fun.getPythonUdf + val function = transformPythonFunction(udf) UserDefinedPythonFunction( name = fun.getFunctionName, - func = transformPythonFunction(udf), + func = function, dataType = transformDataType(udf.getOutputType), pythonEvalType = udf.getEvalType, udfDeterministic = fun.getDeterministic) - .builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) match { - case udaf: PythonUDAF => udaf.toAggregateExpression() - case other => other - } } private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = { @@ -2584,15 +2592,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = { - val udf = fun.getPythonUdf - val function = transformPythonFunction(udf) - val udpf = UserDefinedPythonFunction( - name = fun.getFunctionName, - func = function, - dataType = transformDataType(udf.getOutputType), - pythonEvalType = udf.getEvalType, - udfDeterministic = fun.getDeterministic) - + val udpf = createUserDefinedPythonFunction(fun) session.udf.registerPython(fun.getFunctionName, udpf) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index bce0713339228..0e4f344da901c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -114,6 +114,9 @@ private[connect] class ExecuteHolder( : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] = new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]() + /** For testing. Whether the async completion callback is called. */ + @volatile private[connect] var completionCallbackCalled: Boolean = false + /** * Start the execution. The execution is started in a background thread in ExecuteThreadRunner. * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the @@ -125,13 +128,6 @@ private[connect] class ExecuteHolder( runner.start() } - /** - * Wait for the execution thread to finish and join it. - */ - def join(): Unit = { - runner.join() - } - /** * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them * out on the Grpc response stream. The sender will start from the start of the response stream. 
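For context, the setDeadline hook added to ExecuteGrpcResponseSender above, together with the interrupt plumbing added below, lets tests shorten a sender's lifetime and wake it up immediately. A minimal, self-contained sketch of that adjustable-deadline wait, using simplified stand-in names rather than the real Spark Connect classes:

// Illustrative sketch only (assumed, simplified names; not the actual Spark Connect API).
// A sender waits until it has something to send, it is interrupted, or its deadline passes;
// a test can pull the deadline into the past and notify the lock to wake it up immediately.
object AdjustableDeadlineSketch {
  private val lock = new Object
  private var deadlineTimeMillis = Long.MaxValue
  private var interrupted = false

  def setDeadline(deadlineMs: Long): Unit = lock.synchronized {
    deadlineTimeMillis = deadlineMs
    lock.notifyAll()
  }

  def interrupt(): Unit = lock.synchronized {
    interrupted = true
    lock.notifyAll()
  }

  // Returns true if work became available before interruption or the deadline.
  def awaitWork(hasWork: () => Boolean): Boolean = lock.synchronized {
    while (!interrupted && !hasWork() && System.currentTimeMillis() < deadlineTimeMillis) {
      val remainingMs = deadlineTimeMillis - System.currentTimeMillis()
      if (remainingMs > 0) lock.wait(remainingMs)
    }
    !interrupted && hasWork()
  }
}

In the patch itself the wake-up happens by notifying executionObserver after updating deadlineTimeMillis, which is why setDeadline synchronizes on that observer.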
@@ -183,6 +179,16 @@ private[connect] class ExecuteHolder( } } + // For testing. + private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = synchronized { + grpcResponseSenders.foreach(_.setDeadline(deadlineMs)) + } + + // For testing + private[connect] def interruptGrpcResponseSenders() = synchronized { + grpcResponseSenders.foreach(_.interrupt()) + } + /** * For a short period in ExecutePlan after creation and until runGrpcResponseSender is called, * there is no attached response sender, but yet we start with lastAttachedRpcTime = None, so we @@ -224,8 +230,15 @@ private[connect] class ExecuteHolder( if (closedTime.isEmpty) { // interrupt execution, if still running. runner.interrupt() - // wait for execution to finish, to make sure no more results get pushed to responseObserver - runner.join() + // Do not wait for the execution to finish, clean up resources immediately. + runner.processOnCompletion { _ => + completionCallbackCalled = true + // The execution may not immediately get interrupted, clean up any remaining resources when + // it does. + responseObserver.removeAll() + // post closed to UI + eventsManager.postClosed() + } // interrupt any attached grpcResponseSenders grpcResponseSenders.foreach(_.interrupt()) // if there were still any grpcResponseSenders, register detach time @@ -235,8 +248,6 @@ private[connect] class ExecuteHolder( } // remove all cached responses from observer responseObserver.removeAll() - // post closed to UI - eventsManager.postClosed() closedTime = Some(System.currentTimeMillis()) } } @@ -274,7 +285,7 @@ private[connect] class ExecuteHolder( object ExecuteJobTag { private val prefix = "SparkConnect_OperationTag" - def apply(sessionId: String, userId: String, operationId: String): String = { + def apply(userId: String, sessionId: String, operationId: String): String = { s"${prefix}_" + s"User_${userId}_" + s"Session_${sessionId}_" + diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 1cef02d7e3466..218819d114c12 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -197,7 +197,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[connect] def cacheDataFrameById(dfId: String, df: DataFrame): Unit = { if (dataFrameCache.putIfAbsent(dfId, df) != null) { - SparkException.internalError(s"A dataframe is already associated with id $dfId") + throw SparkException.internalError(s"A dataframe is already associated with id $dfId") } } @@ -221,7 +221,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[connect] def cacheListenerById(id: String, listener: StreamingQueryListener): Unit = { if (listenerCache.putIfAbsent(id, listener) != null) { - SparkException.internalError(s"A listener is already associated with id $id") + throw SparkException.internalError(s"A listener is already associated with id $id") } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index ce1f6c93f6cfe..21f59bdd68ea5 100644 --- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -71,15 +71,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // The latter is to prevent double execution when a client retries execution, thinking it // never reached the server, but in fact it did, and already got removed as abandoned. if (executions.get(executeHolder.key).isDefined) { - if (getAbandonedTombstone(executeHolder.key).isDefined) { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", - messageParameters = Map("handle" -> executeHolder.operationId)) - } else { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", - messageParameters = Map("handle" -> executeHolder.operationId)) - } + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", + messageParameters = Map("handle" -> executeHolder.operationId)) + } + if (getAbandonedTombstone(executeHolder.key).isDefined) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", + messageParameters = Map("handle" -> executeHolder.operationId)) } sessionHolder.addExecuteHolder(executeHolder) executions.put(executeHolder.key, executeHolder) @@ -141,12 +140,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging { abandonedTombstones.asMap.asScala.values.toBuffer.toSeq } - private[service] def shutdown(): Unit = executionsLock.synchronized { + private[connect] def shutdown(): Unit = executionsLock.synchronized { scheduledExecutor.foreach { executor => executor.shutdown() executor.awaitTermination(1, TimeUnit.MINUTES) } scheduledExecutor = None + executions.clear() + abandonedTombstones.invalidateAll() + if (!lastExecutionTime.isDefined) { + lastExecutionTime = Some(System.currentTimeMillis()) + } } /** @@ -188,7 +192,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { executions.values.foreach { executeHolder => executeHolder.lastAttachedRpcTime match { case Some(detached) => - if (detached + timeout < nowMs) { + if (detached + timeout <= nowMs) { toRemove += executeHolder } case _ => // execution is active @@ -206,4 +210,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") } + + // For testing. + private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized { + executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) + } + + // For testing. + private[connect] def interruptAllRPCs() = executionsLock.synchronized { + executions.values.foreach(_.interruptGrpcResponseSenders()) + } + + private[connect] def listExecuteHolders = executionsLock.synchronized { + executions.values.toBuffer.toSeq + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala new file mode 100644 index 0000000000000..eddd1c6be72b1 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import java.util.UUID + +import org.scalatest.concurrent.{Eventually, TimeLimits} +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, WrappedCloseableIterator} +import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.dsl.MockRemoteSession +import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService} +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Base class and utilities for a test suite that starts and tests the real SparkConnectService + * with a real SparkConnectClient, communicating over RPC, but both in-process. + */ +trait SparkConnectServerTest extends SharedSparkSession { + + // Server port + val serverPort: Int = + ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) + + val eventuallyTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + // Other suites using mocks leave a mess in the global executionManager, + // shut it down so that it's cleared before starting server. + SparkConnectService.executionManager.shutdown() + // Start the real service. + withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { + SparkConnectService.start(spark.sparkContext) + } + // register udf directly on the server, we're not testing client UDFs here... + val serverSession = + SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session + serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) + } + + override def afterAll(): Unit = { + SparkConnectService.stop() + super.afterAll() + } + + override def beforeEach(): Unit = { + super.beforeEach() + clearAllExecutions() + } + + override def afterEach(): Unit = { + clearAllExecutions() + super.afterEach() + } + + protected def clearAllExecutions(): Unit = { + SparkConnectService.executionManager.listExecuteHolders.foreach(_.close()) + SparkConnectService.executionManager.periodicMaintenance(0) + assertNoActiveExecutions() + } + + protected val defaultSessionId = UUID.randomUUID.toString() + protected val defaultUserId = UUID.randomUUID.toString() + + // We don't have the real SparkSession/Dataset api available, + // so use mock for generating simple query plans. 
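As orientation only, a hypothetical suite built on this trait would drive queries through the helpers defined further down (runQuery, assertEventuallyNoActiveExecutions); the suite name and query below are made up, not part of the patch:

// Hypothetical usage of this trait's helpers; illustrative only.
import org.scalatest.time.SpanSugar._

class ExampleReattachSuite extends SparkConnectServerTest {
  test("small query round trip") {
    // Drives a reattachable ExecutePlan end-to-end against the in-process server
    // and checks the execution is eventually released on the server side.
    runQuery("select * from range(10)", queryTimeout = 30.seconds)
    assertEventuallyNoActiveExecutions()
  }
}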
+ protected val dsl = new MockRemoteSession() + + protected val userContext = proto.UserContext + .newBuilder() + .setUserId(defaultUserId) + .build() + + protected def buildExecutePlanRequest( + plan: proto.Plan, + sessionId: String = defaultSessionId, + operationId: String = UUID.randomUUID.toString) = { + proto.ExecutePlanRequest + .newBuilder() + .setUserContext(userContext) + .setSessionId(sessionId) + .setOperationId(operationId) + .setPlan(plan) + .addRequestOptions( + proto.ExecutePlanRequest.RequestOption + .newBuilder() + .setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true).build()) + .build()) + .build() + } + + protected def buildReattachExecuteRequest(operationId: String, responseId: Option[String]) = { + val req = proto.ReattachExecuteRequest + .newBuilder() + .setUserContext(userContext) + .setSessionId(defaultSessionId) + .setOperationId(operationId) + + if (responseId.isDefined) { + req.setLastResponseId(responseId.get) + } + + req.build() + } + + protected def buildPlan(query: String) = { + proto.Plan.newBuilder().setRoot(dsl.sql(query)).build() + } + + protected def getReattachableIterator( + stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = { + // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable: + // GrpcExceptionConverter.convertIterator + stubIterator + .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] + // ExecutePlanResponseReattachableIterator + .innerIterator + .asInstanceOf[ExecutePlanResponseReattachableIterator] + } + + protected def assertNoActiveRpcs(): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // nothing running, good + case Right(executions) => + // all rpc detached. + assert( + executions.forall(_.lastAttachedRpcTime.isDefined), + s"Expected no RPCs, but got $executions") + } + } + + protected def assertEventuallyNoActiveRpcs(): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertNoActiveRpcs() + } + } + + protected def assertNoActiveExecutions(): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // cleaned up + case Right(executions) => fail(s"Expected empty, but got $executions") + } + } + + protected def assertEventuallyNoActiveExecutions(): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertNoActiveExecutions() + } + } + + protected def assertExecutionReleased(operationId: String): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // cleaned up + case Right(executions) => assert(!executions.exists(_.operationId == operationId)) + } + } + + protected def assertEventuallyExecutionReleased(operationId: String): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertExecutionReleased(operationId) + } + } + + // Get ExecutionHolder, assuming that only one execution is active + protected def getExecutionHolder: ExecuteHolder = { + val executions = SparkConnectService.executionManager.listExecuteHolders + assert(executions.length == 1) + executions.head + } + + protected def withClient(f: SparkConnectClient => Unit): Unit = { + val client = SparkConnectClient + .builder() + .port(serverPort) + .sessionId(defaultSessionId) + .userId(defaultUserId) + .enableReattachableExecute() + .build() + try f(client) + finally { + client.shutdown() + } + } + + protected def withRawBlockingStub( + f: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub => Unit): Unit = { + val 
conf = SparkConnectClient.Configuration(port = serverPort) + val channel = conf.createChannel() + val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) + try f(bstub) + finally { + channel.shutdownNow() + } + } + + protected def withCustomBlockingStub( + retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy())( + f: CustomSparkConnectBlockingStub => Unit): Unit = { + val conf = SparkConnectClient.Configuration(port = serverPort) + val channel = conf.createChannel() + val bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy) + try f(bstub) + finally { + channel.shutdownNow() + } + } + + protected def runQuery(plan: proto.Plan, queryTimeout: Span, iterSleep: Long): Unit = { + withClient { client => + TimeLimits.failAfter(queryTimeout) { + val iter = client.execute(plan) + var operationId: Option[String] = None + var r: proto.ExecutePlanResponse = null + val reattachableIter = getReattachableIterator(iter) + while (iter.hasNext) { + r = iter.next() + operationId match { + case None => operationId = Some(r.getOperationId) + case Some(id) => assert(r.getOperationId == id) + } + if (iterSleep > 0) { + Thread.sleep(iterSleep) + } + } + // Check that last response had ResultComplete indicator + assert(r != null) + assert(r.hasResultComplete) + // ... that client sent ReleaseExecute based on it + assert(reattachableIter.resultComplete) + // ... and that the server released the execution. + assert(operationId.isDefined) + assertEventuallyExecutionReleased(operationId.get) + } + } + } + + protected def runQuery(query: String, queryTimeout: Span, iterSleep: Long = 0): Unit = { + val plan = buildPlan(query) + runQuery(plan, queryTimeout, iterSleep) + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala new file mode 100644 index 0000000000000..06cd1a5666b66 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.execution + +import java.util.UUID + +import io.grpc.StatusRuntimeException +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.SparkConnectServerTest +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService + +class ReattachableExecuteSuite extends SparkConnectServerTest { + + // Tests assume that this query will result in at least a couple ExecutePlanResponses on the + // stream. If this is no longer the case because of changes in how much is returned in a single + // ExecutePlanResponse, it may need to be adjusted. + val MEDIUM_RESULTS_QUERY = "select * from range(10000000)" + + test("reattach after initial RPC ends") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + + // open the iterator + iter.next() + // expire all RPCs on server + SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + assertEventuallyNoActiveRpcs() + // iterator should reattach + // (but not necessarily at first next, as there might have been messages buffered client side) + while (iter.hasNext && (reattachableIter.innerIterator eq initialInnerIter)) { + iter.next() + } + assert( + reattachableIter.innerIterator ne initialInnerIter + ) // reattach changed the inner iter + } + } + + test("raw interrupted RPC results in INVALID_CURSOR.DISCONNECTED error") { + withRawBlockingStub { stub => + val iter = stub.executePlan(buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY))) + iter.next() // open the iterator + // interrupt all RPCs on server + SparkConnectService.executionManager.interruptAllRPCs() + assertEventuallyNoActiveRpcs() + val e = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + } + } + + test("raw new RPC interrupts previous RPC with INVALID_CURSOR.DISCONNECTED error") { + // Raw stub does not have retries, auto reattach etc. 
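As an aside on what "no retries, auto reattach" means below: with the raw blocking stub any resumption has to be done by hand, roughly like the assumed helper here, which reuses only builders and calls that appear in this suite (the helper name itself is made up):

// Hypothetical helper, not part of the patch: drain an ExecutePlan stream on the raw
// blocking stub and, if the server drops the RPC with INVALID_CURSOR.DISCONNECTED,
// resume once by reattaching from the last response id seen.
import io.grpc.StatusRuntimeException
import org.apache.spark.connect.proto

object RawReattachSketch {
  def drainWithOneReattach(
      stub: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub,
      request: proto.ExecutePlanRequest): Long = {
    var lastResponseId: Option[String] = None
    var count = 0L
    def consume(it: java.util.Iterator[proto.ExecutePlanResponse]): Unit =
      while (it.hasNext) {
        val r = it.next()
        lastResponseId = Some(r.getResponseId)
        count += 1
      }
    try consume(stub.executePlan(request))
    catch {
      case e: StatusRuntimeException if e.getMessage.contains("INVALID_CURSOR.DISCONNECTED") =>
        val reattach = proto.ReattachExecuteRequest
          .newBuilder()
          .setUserContext(request.getUserContext)
          .setSessionId(request.getSessionId)
          .setOperationId(request.getOperationId)
        lastResponseId.foreach(id => reattach.setLastResponseId(id))
        consume(stub.reattachExecute(reattach.build()))
    }
    count
  }
}

The reattachable client iterator used in the other tests performs roughly this recovery automatically, with retry policies on top.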
+ withRawBlockingStub { stub => + val operationId = UUID.randomUUID().toString + val iter = stub.executePlan( + buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId)) + iter.next() // open the iterator + + // send reattach + val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + iter2.next() // open the iterator + + // should result in INVALID_CURSOR.DISCONNECTED error on the original iterator + val e = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + + // send another reattach + val iter3 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + assert(iter3.hasNext) + iter3.next() // open the iterator + + // should result in INVALID_CURSOR.DISCONNECTED error on the previous reattach iterator + val e2 = intercept[StatusRuntimeException] { + while (iter2.hasNext) iter2.next() + } + assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + } + } + + test("client INVALID_CURSOR.DISCONNECTED error is retried when rpc sender gets interrupted") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + // open the iterator + iter.next() + + // interrupt all RPCs on server + SparkConnectService.executionManager.interruptAllRPCs() + assertEventuallyNoActiveRpcs() + + // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error + iter.next() + // iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + + test("client INVALID_CURSOR.DISCONNECTED error is retried when other RPC preempts this one") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + // open the iterator + val response = iter.next() + + // Send another Reattach request, it should preempt this request with an + // INVALID_CURSOR.DISCONNECTED error. 
+ withRawBlockingStub { stub => + val reattachIter = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response.getResponseId))) + assert(reattachIter.hasNext) + } + + // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error + iter.next() + // iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + + test("abandoned query gets INVALID_HANDLE.OPERATION_ABANDONED error") { + withClient { client => + val plan = buildPlan("select * from range(100000)") + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val operationId = getReattachableIterator(iter).operationId + // open the iterator + iter.next() + // disconnect and remove on server + SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + assertEventuallyNoActiveRpcs() + SparkConnectService.executionManager.periodicMaintenance(0) + assertNoActiveExecutions() + // check that it throws abandoned error + val e = intercept[SparkException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + // check that afterwards, new operation can't be created with the same operationId. + withCustomBlockingStub() { stub => + val executePlanReq = buildExecutePlanRequest(plan, operationId = operationId) + + val iterNonReattachable = stub.executePlan(executePlanReq) + val eNonReattachable = intercept[SparkException] { + iterNonReattachable.hasNext + } + assert(eNonReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + val iterReattachable = stub.executePlanReattachable(executePlanReq) + val eReattachable = intercept[SparkException] { + iterReattachable.hasNext + } + assert(eReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + } + } + } + + test("client releases responses directly after consuming them") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + assert(iter.hasNext) // open iterator + val execution = getExecutionHolder + assert(execution.responseObserver.releasedUntilIndex == 0) + + // get two responses, check on the server that ReleaseExecute releases them afterwards + val response1 = iter.next() + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex == 1) + } + + val response2 = iter.next() + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex == 2) + } + + withRawBlockingStub { stub => + // Reattach after response1 should fail with INVALID_CURSOR.POSITION_NOT_AVAILABLE + val reattach1 = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response1.getResponseId))) + val e = intercept[StatusRuntimeException] { + reattach1.hasNext() + } + assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE")) + + // Reattach after response2 should work + val reattach2 = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response2.getResponseId))) + val response3 = reattach2.next() + val response4 = reattach2.next() + val response5 = reattach2.next() + + // The original client iterator will handle the INVALID_CURSOR.DISCONNECTED error, + // and reconnect back. 
Since the raw iterator was not releasing responses, client iterator + // should be able to continue where it left off (server shouldn't have released yet) + assert(execution.responseObserver.releasedUntilIndex == 2) + assert(iter.hasNext) + + val r3 = iter.next() + assert(r3.getResponseId == response3.getResponseId) + val r4 = iter.next() + assert(r4.getResponseId == response4.getResponseId) + val r5 = iter.next() + assert(r5.getResponseId == response5.getResponseId) + // inner iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + } + + test("server releases responses automatically when client moves ahead") { + withRawBlockingStub { stub => + val operationId = UUID.randomUUID().toString + val iter = stub.executePlan( + buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId)) + var lastSeenResponse: String = null + val serverRetryBuffer = SparkEnv.get.conf + .get(Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE) + .toLong + + iter.hasNext // open iterator + val execution = getExecutionHolder + + // after consuming enough from the iterator, server should automatically start releasing + var lastSeenIndex = 0 + var totalSizeSeen = 0 + while (iter.hasNext && totalSizeSeen <= 1.1 * serverRetryBuffer) { + val r = iter.next() + lastSeenResponse = r.getResponseId() + totalSizeSeen += r.getSerializedSize + lastSeenIndex += 1 + } + assert(iter.hasNext) + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex > 0) + } + + // Reattach from the beginning is not available. + val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + val e = intercept[StatusRuntimeException] { + reattach.hasNext() + } + assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE")) + + // Original iterator got disconnected by the reattach and gets INVALID_CURSOR.DISCONNECTED + val e2 = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + + Eventually.eventually(timeout(eventuallyTimeout)) { + // Even though we didn't consume more from the iterator, the server thinks that + // it sent more, because GRPC stream onNext() can push into internal GRPC buffer without + // client picking it up. + assert(execution.responseObserver.highestConsumedIndex > lastSeenIndex) + } + // but CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE is big enough that the last + // response we've seen is still in range + assert(execution.responseObserver.releasedUntilIndex < lastSeenIndex) + + // and a new reattach can continue after what there. + val reattach2 = + stub.reattachExecute(buildReattachExecuteRequest(operationId, Some(lastSeenResponse))) + assert(reattach2.hasNext) + while (reattach2.hasNext) reattach2.next() + } + } + + // A few integration tests with large results. + // They should run significantly faster than the LARGE_QUERY_TIMEOUT + // - big query (4 seconds, 871 milliseconds) + // - big query and slow client (7 seconds, 288 milliseconds) + // - big query with frequent reattach (1 second, 527 milliseconds) + // - big query with frequent reattach and slow client (7 seconds, 365 milliseconds) + // - long sleeping query (10 seconds, 805 milliseconds) + + // intentionally smaller than CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, + // so that reattach deadline doesn't "unstuck" if something got stuck. 
+ val LARGE_QUERY_TIMEOUT = 100.seconds + + val LARGE_RESULTS_QUERY = s"select id, " + + (1 to 20).map(i => s"cast(id as string) c$i").mkString(", ") + + s" from range(1000000)" + + test("big query") { + // regular query with large results + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + + test("big query and slow client") { + // regular query with large results, but client is slow so sender will need to control flow + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + + test("big query with frequent reattach") { + // will reattach every 100kB + withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) { + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("big query with frequent reattach and slow client") { + // will reattach every 100kB, and in addition the client is slow, + // so sender will need to control flow + withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) { + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("long sleeping query") { + // query will be sleeping and not returning results, while having multiple reattach + withSparkEnvConfs( + (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) { + runQuery("select sleep(10000) as s", 30.seconds) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("Async cleanup callback gets called after the execution is closed") { + withClient { client => + val query1 = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + // just creating the iterator is lazy, trigger query1 to be sent. 
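The completionCallbackCalled flag asserted a few lines below is set from ExecuteThreadRunner's promise-backed completion callback; a stand-alone sketch of that pattern with assumed simplified names (the real code uses a daemon single-thread executor via ThreadUtils):

// Assumed, simplified sketch (not the real ExecuteThreadRunner): the execution thread
// completes a promise on success or failure, and callers register a callback on the
// future instead of joining the thread, so cleanup can run asynchronously.
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Promise}
import scala.util.Try
import scala.util.control.NonFatal

class RunnerSketch(body: () => Unit) {
  private val promise = Promise[Unit]()
  private val ec = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())

  private val thread = new Thread(new Runnable {
    override def run(): Unit =
      try { body(); promise.success(()) }
      catch { case NonFatal(e) => promise.failure(e) }
  })

  def start(): Unit = thread.start()

  // Mirrors processOnCompletion in the patch: the callback fires once the body finishes,
  // e.g. flipping a completionCallbackCalled-style flag and releasing cached responses.
  def processOnCompletion(callback: Try[Unit] => Unit): Unit =
    promise.future.onComplete(callback)(ec)
}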
+ query1.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 1) + } + val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head + // Close execution + SparkConnectService.executionManager.removeExecuteHolder(executeHolder1.key) + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 0) + } + // Check the async execute cleanup get called + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(executeHolder1.completionCallbackCalled) + } + } + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 0c12bf5e625a9..8bc4de8351248 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Observation, import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto @@ -1067,7 +1067,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { // Compares proto plan with LogicalPlan. 
private def comparePlans(connectPlan: proto.Relation, sparkPlan: LogicalPlan): Unit = { + def normalizeDataframeId(plan: LogicalPlan): LogicalPlan = plan transform { + case cm: CollectMetrics => cm.copy(dataframeId = 0) + } val connectAnalyzed = analyzePlan(transform(connectPlan)) - comparePlans(connectAnalyzed, sparkPlan, false) + comparePlans(normalizeDataframeId(connectAnalyzed), normalizeDataframeId(sparkPlan), false) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 90c9d13def616..06508bfc6a7c2 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader import org.mockito.Mockito.when import org.scalatest.Tag +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkContext, SparkEnv} @@ -879,8 +881,11 @@ class SparkConnectServiceSuite assert(executeHolder.eventsManager.hasError.isDefined) } def onCompleted(producedRowCount: Option[Long] = None): Unit = { - assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) + // The eventsManager is closed asynchronously + Eventually.eventually(timeout(1.seconds)) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } } def onCanceled(): Unit = { assert(executeHolder.eventsManager.hasCanceled.contains(true)) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala similarity index 100% rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala rename to connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala index 3b75c37b2aa00..c9c110dd1e626 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala @@ -37,7 +37,7 @@ class SparkConnectServerListenerSuite private var kvstore: ElementTrackingStore = _ - private val jobTag = ExecuteJobTag("sessionId", "userId", "operationId") + private val jobTag = ExecuteJobTag("userId", "sessionId", "operationId") after { if (kvstore != null) { @@ -174,7 +174,7 @@ class SparkConnectServerListenerSuite SparkListenerJobStart(0, System.currentTimeMillis(), Nil, createProperties)) listener.onOtherEvent( SparkListenerConnectSessionClosed("sessionId", "userId", System.currentTimeMillis())) - val exec = statusStore.getExecution(ExecuteJobTag("sessionId", "userId", "operationId")) + val exec = 
statusStore.getExecution(ExecuteJobTag("userId", "sessionId", "operationId")) assert(exec.isDefined) assert(exec.get.jobId === Seq("0")) assert(exec.get.sqlExecId === Set("0")) @@ -190,7 +190,7 @@ class SparkConnectServerListenerSuite listener.onOtherEvent(SparkListenerConnectSessionClosed(unknownSession, "userId", 0)) listener.onOtherEvent( SparkListenerConnectOperationStarted( - ExecuteJobTag("sessionId", "userId", "operationId"), + ExecuteJobTag("userId", "sessionId", "operationId"), "operationId", System.currentTimeMillis(), unknownSession, diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index 87df8a9ff5bea..19377b36a612f 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml @@ -46,22 +46,6 @@ - - com.spotify - docker-client - test - shaded - - - org.apache.httpcomponents - httpclient - test - - - org.apache.httpcomponents - httpcore - test - com.google.guava @@ -112,14 +96,6 @@ hadoop-minikdc test - - - org.glassfish.jersey.bundles.repackaged - jersey-guava - 2.25.1 - test - org.mariadb.jdbc mariadb-java-client @@ -167,5 +143,15 @@ mysql-connector-j test + + com.github.docker-java + docker-java + test + + + com.github.docker-java + docker-java-transport-zerodep + test + diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala index 9b518d61d252f..66e2afbb6effd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.security.PrivilegedExceptionAction import java.sql.Connection import javax.security.auth.login.Configuration -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS import org.scalatest.time.SpanSugar._ @@ -66,14 +66,15 @@ class DB2KrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { } override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("db2_krb_setup.sh", initDbDir, replaceIp) - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(initDbDir.getAbsolutePath) - .to("/var/custom").readOnly(true).build() - ) + val newBind = new Bind( + initDbDir.getAbsolutePath, + new Volume("/var/custom"), + AccessMode.ro) + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds :+ newBind: _*) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index 40e8cbb6546b5..837382239514a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -20,14 +20,18 @@ package org.apache.spark.sql.jdbc import java.net.ServerSocket import java.sql.{Connection, DriverManager} import java.util.Properties +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import com.spotify.docker.client._ -import com.spotify.docker.client.DockerClient.{ListContainersParam, LogsParam} -import com.spotify.docker.client.exceptions.ImageNotFoundException -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import com.github.dockerjava.api.DockerClient +import com.github.dockerjava.api.async.{ResultCallback, ResultCallbackTemplate} +import com.github.dockerjava.api.command.CreateContainerResponse +import com.github.dockerjava.api.exception.NotFoundException +import com.github.dockerjava.api.model._ +import com.github.dockerjava.core.{DefaultDockerClientConfig, DockerClientImpl} +import com.github.dockerjava.zerodep.ZerodepDockerHttpClient import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ @@ -88,8 +92,8 @@ abstract class DatabaseOnDocker { * Optional step before container starts */ def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = {} + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = {} } abstract class DockerJDBCIntegrationSuite @@ -97,7 +101,7 @@ abstract class DockerJDBCIntegrationSuite protected val dockerIp = DockerUtils.getDockerIp() val db: DatabaseOnDocker - val connectionTimeout = timeout(5.minutes) + val connectionTimeout = timeout(10.minutes) val keepContainer = sys.props.getOrElse("spark.test.docker.keepContainer", "false").toBoolean val removePulledImage = @@ -111,56 +115,75 @@ abstract class DockerJDBCIntegrationSuite sock.close() port } - private var containerId: String = _ + private var container: CreateContainerResponse = _ private var pulled: Boolean = false protected var jdbcUrl: String = _ override def beforeAll(): Unit = runIfTestsEnabled(s"Prepare for ${this.getClass.getName}") { super.beforeAll() try { - docker = DefaultDockerClient.fromEnv.build() + val config = DefaultDockerClientConfig.createDefaultConfigBuilder.build + val httpClient = new ZerodepDockerHttpClient.Builder() + .dockerHost(config.getDockerHost) + .sslConfig(config.getSSLConfig) + .build() + docker = DockerClientImpl.getInstance(config, httpClient) // Check that Docker is actually up try { - docker.ping() + docker.pingCmd().exec() } catch { case NonFatal(e) => log.error("Exception while connecting to Docker. 
Check whether Docker is running.") throw e } - // Ensure that the Docker image is installed: try { - docker.inspectImage(db.imageName) + // Ensure that the Docker image is installed: + docker.inspectImageCmd(db.imageName).exec() } catch { - case e: ImageNotFoundException => + case e: NotFoundException => log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") - docker.pull(db.imageName) + docker.pullImageCmd(db.imageName) + .start() + .awaitCompletion(connectionTimeout.value.toSeconds, TimeUnit.SECONDS) pulled = true } - val hostConfigBuilder = HostConfig.builder() - .privileged(db.privileged) - .networkMode("bridge") - .ipcMode(if (db.usesIpc) "host" else "") - .portBindings( - Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) - // Create the database container: - val containerConfigBuilder = ContainerConfig.builder() - .image(db.imageName) - .networkDisabled(false) - .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) - .exposedPorts(s"${db.jdbcPort}/tcp") - if (db.getEntryPoint.isDefined) { - containerConfigBuilder.entrypoint(db.getEntryPoint.get) - } - if (db.getStartupProcessName.isDefined) { - containerConfigBuilder.cmd(db.getStartupProcessName.get) + + docker.pullImageCmd(db.imageName) + .start() + .awaitCompletion(connectionTimeout.value.toSeconds, TimeUnit.SECONDS) + + val hostConfig = HostConfig + .newHostConfig() + .withNetworkMode("bridge") + .withPrivileged(db.privileged) + .withPortBindings(PortBinding.parse(s"$externalPort:${db.jdbcPort}")) + + if (db.usesIpc) { + hostConfig.withIpcMode("host") } - db.beforeContainerStart(hostConfigBuilder, containerConfigBuilder) - containerConfigBuilder.hostConfig(hostConfigBuilder.build()) - val config = containerConfigBuilder.build() + + val containerConfig = new ContainerConfig() + + db.beforeContainerStart(hostConfig, containerConfig) + // Create the database container: - containerId = docker.createContainer(config).id + val createContainerCmd = docker.createContainerCmd(db.imageName) + .withHostConfig(hostConfig) + .withExposedPorts(ExposedPort.tcp(db.jdbcPort)) + .withEnv(db.env.map { case (k, v) => s"$k=$v" }.toList.asJava) + .withNetworkDisabled(false) + + + db.getEntryPoint.foreach(ep => createContainerCmd.withEntrypoint(ep)) + db.getStartupProcessName.foreach(n => createContainerCmd.withCmd(n)) + + container = createContainerCmd.exec() // Start the container and wait until the database can accept JDBC connections: - docker.startContainer(containerId) + docker.startContainerCmd(container.getId).exec() + eventually(connectionTimeout, interval(1.second)) { + val response = docker.inspectContainerCmd(container.getId).exec() + assert(response.getState.getRunning) + } jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) var conn: Connection = null eventually(connectionTimeout, interval(1.second)) { @@ -174,6 +197,7 @@ abstract class DockerJDBCIntegrationSuite } } catch { case NonFatal(e) => + logError(s"Failed to initialize Docker container for ${this.getClass.getName}", e) try { afterAll() } finally { @@ -206,36 +230,35 @@ abstract class DockerJDBCIntegrationSuite def dataPreparation(connection: Connection): Unit private def cleanupContainer(): Unit = { - if (docker != null && containerId != null && !keepContainer) { + if (docker != null && container != null && !keepContainer) { try { - docker.killContainer(containerId) + docker.killContainerCmd(container.getId).exec() } catch { case NonFatal(e) => - val exitContainerIds = - 
docker.listContainers(ListContainersParam.withStatusExited()).asScala.map(_.id()) - if (exitContainerIds.contains(containerId)) { - logWarning(s"Container $containerId already stopped") - } else { - logWarning(s"Could not stop container $containerId", e) - } + val response = docker.inspectContainerCmd(container.getId).exec() + val status = Option(response).map(_.getState.getStatus).getOrElse("unknown") + if (status == "exited") { + logWarning(s"Container $container already stopped") + } else { + logWarning(s"Could not stop container $container at stage '$status'", e) + } } finally { logContainerOutput() - docker.removeContainer(containerId) + docker.removeContainerCmd(container.getId).exec() if (removePulledImage && pulled) { - docker.removeImage(db.imageName) + docker.removeImageCmd(db.imageName).exec() } } } } private def logContainerOutput(): Unit = { - val logStream = docker.logs(containerId, LogsParam.stdout(), LogsParam.stderr()) - try { - logInfo("\n\n===== CONTAINER LOGS FOR container Id: " + containerId + " =====") - logInfo(logStream.readFully()) - logInfo("\n\n===== END OF CONTAINER LOGS FOR container Id: " + containerId + " =====") - } finally { - logStream.close() - } + logInfo("\n\n===== CONTAINER LOGS FOR container Id: " + container + " =====") + docker.logContainerCmd(container.getId) + .withStdOut(true) + .withStdErr(true) + .withFollowStream(true) + .withSince(0).exec( + new ResultCallbackTemplate[ResultCallback[Frame], Frame] { + override def onNext(f: Frame): Unit = logInfo(f.toString) + }) + logInfo("\n\n===== END OF CONTAINER LOGS FOR container Id: " + container + " =====") } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala index 873d5ad1ee43b..49c9e3dba0d7f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import javax.security.auth.login.Configuration -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider import org.apache.spark.tags.DockerTest @@ -52,17 +52,17 @@ class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { Some("/docker-entrypoint/mariadb_docker_entrypoint.sh") override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("mariadb_docker_entrypoint.sh", entryPointDir, replaceIp) copyExecutableResource("mariadb_krb_setup.sh", initDbDir, replaceIp) - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(entryPointDir.getAbsolutePath) - .to("/docker-entrypoint").readOnly(true).build(), - HostConfig.Bind.from(initDbDir.getAbsolutePath) - .to("/docker-entrypoint-initdb.d").readOnly(true).build() - ) + val binds = + Seq(entryPointDir -> "/docker-entrypoint", initDbDir -> "/docker-entrypoint-initdb.d") + .map { case (from, to) => + new Bind(from.getAbsolutePath, new Volume(to), AccessMode.ro) + } + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds ++ binds: _*) } } diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala new file mode 100644 index 0000000000000..b351b2ad1ec7d --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", + "mcr.microsoft.com/mssql/server:2022-CU12-GDR1-ubuntu-22.04") + override val env = Map( + "SA_PASSWORD" -> "Sapass123", + "ACCEPT_EULA" -> "Y" + ) + override val usesIpc = false + override val jdbcPort: Int = 1433 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index f2614f46bc3f6..443000050a476 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -38,19 +38,7 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } + override val db = new MsSQLServerDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE TABLE tbl (x INT, y VARCHAR (50))").executeUpdate() diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index dc3acb66ff1f4..cefbe41b64bd3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -43,7 
+43,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val usesIpc = false override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&disableMariaDbDriver" } override def dataPreparation(conn: Connection): Unit = { @@ -56,10 +56,14 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " - + "dbl DOUBLE)").executeUpdate() + + "dbl DOUBLE, tiny TINYINT, u_tiny TINYINT UNSIGNED)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " - + "42.75, 1.0000000000000002)").executeUpdate() + + "42.75, 1.0000000000000002, -128, 255)").executeUpdate() + + conn.prepareStatement("INSERT INTO numbers VALUES (null, null, " + + "null, null, null, null, null, null, null, null, null)").executeUpdate() conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + "yr YEAR)").executeUpdate() @@ -74,6 +78,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { "'jumps', 'over', 'the', 'lazy', 'dog', '{\"status\": \"merrily\"}')").executeUpdate() } + def testConnection(): Unit = { + val conn = getConnection() + try { + assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl") + } finally { + conn.close() + } + } + + test("SPARK-47537: ensure use the right jdbc driver") { + testConnection() + } + test("Basic test") { val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) val rows = df.collect() @@ -87,9 +104,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { test("Numeric types") { val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) val rows = df.collect() - assert(rows.length == 1) + assert(rows.length == 2) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) + assert(types.length == 11) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Long")) assert(types(2).equals("class java.lang.Integer")) @@ -99,6 +116,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(types(6).equals("class java.math.BigDecimal")) assert(types(7).equals("class java.lang.Double")) assert(types(8).equals("class java.lang.Double")) + assert(types(9).equals("class java.lang.Byte")) + assert(types(10).equals("class java.lang.Short")) assert(rows(0).getBoolean(0) == false) assert(rows(0).getLong(1) == 0x225) assert(rows(0).getInt(2) == 17) @@ -109,6 +128,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getAs[BigDecimal](6).equals(bd)) assert(rows(0).getDouble(7) == 42.75) assert(rows(0).getDouble(8) == 1.0000000000000002) + assert(rows(0).getByte(9) == 0x80.toByte) + assert(rows(0).getShort(10) == 0xff.toShort) } test("Date types") { @@ -194,4 +215,50 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { """.stripMargin.replaceAll("\n", " ")) assert(sql("select x, y from queryOption").collect.toSet == expectedResult) } + + test("SPARK-47666: Check nulls for result set getters") { + val nulls = spark.read.jdbc(jdbcUrl, "numbers", new Properties).tail(1).head + assert(nulls === Row(null, null, null, null, null, null, null, null, null, 
null, null)) + } + + test("SPARK-44638: Char/Varchar in Custom Schema") { + val df = spark.read.option("url", jdbcUrl) + .option("query", "SELECT c, d from strings") + .option("customSchema", "c CHAR(10), d VARCHAR(10)") + .format("jdbc") + .load() + assert(df.head === Row("brown ", "fox")) + } +} + +/** + * To run this test suite for a specific version (e.g., mysql:8.3.0): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.3.0 + * ./build/sbt -Pdocker-integration-tests + * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" + * }}} + */ +@DockerTest +class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { + + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def testConnection(): Unit = { + val conn = getConnection() + try { + assert(conn.getClass.getName === "org.mariadb.jdbc.MariaDbConnection") + } finally { + conn.close() + } + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 483f6087c81d2..70afad781ca25 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -173,8 +173,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark } - // SPARK-43049: Use CLOB instead of VARCHAR(255) for StringType for Oracle jdbc-am"" - test("SPARK-12941: String datatypes to be mapped to CLOB in Oracle") { + test("SPARK-12941: String datatypes to be mapped to VARCHAR(255) in Oracle") { // create a sample dataframe with string type val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") // write the dataframe to the oracle table tbl diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 90d6f6ae2fbfc..23fbf39db3be0 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -20,19 +20,19 @@ package org.apache.spark.sql.jdbc import java.math.{BigDecimal => JBigDecimal} import java.sql.{Connection, Date, Timestamp} import java.text.SimpleDateFormat -import java.time.{LocalDateTime, ZoneOffset} +import java.time.LocalDateTime import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} +import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 
POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -445,9 +445,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row.length == 2) val infinity = row(0).getAs[Timestamp]("timestamp_column") val negativeInfinity = row(1).getAs[Timestamp]("timestamp_column") - val minTimeStamp = LocalDateTime.of(1, 1, 1, 0, 0, 0).toEpochSecond(ZoneOffset.UTC) - val maxTimestamp = LocalDateTime.of(9999, 12, 31, 23, 59, 59).toEpochSecond(ZoneOffset.UTC) - + val minTimeStamp = -62135596800000L + val maxTimestamp = 253402300799999L assert(infinity.getTime == maxTimestamp) assert(negativeInfinity.getTime == minTimeStamp) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index 4debe24754de3..1dcf101b394a4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql.jdbc import javax.security.auth.login.Configuration -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresKrbIntegrationSuite" * }}} */ @@ -37,7 +37,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -48,14 +48,14 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { s"jdbc:postgresql://$ip:$port/postgres?user=$principal&gsslib=gssapi" override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("postgres_krb_setup.sh", initDbDir, replaceIp) - - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(initDbDir.getAbsolutePath) - 
.to("/docker-entrypoint-initdb.d").readOnly(true).build() - ) + val newBind = new Bind( + initDbDir.getAbsolutePath, + new Volume("/docker-entrypoint-initdb.d"), + AccessMode.ro) + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds :+ newBind: _*) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 661b1277e9f03..5bcc8afefb1dd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -80,16 +80,24 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { connection.prepareStatement( "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") .executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE DOUBLE") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", DoubleType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", DoubleType, true, defaultMetadata(DoubleType)) assert(t.schema === expectedSchema) // Update column type from DOUBLE to STRING val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE VARCHAR(10)" @@ -112,7 +120,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('CCSID'='UNICODE')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala index 72edfc9f1bf1c..60345257f2dc4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -38,6 +38,25 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { .executeUpdate() connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") .executeUpdate() + + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_quote''_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_quote_not_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_percent%_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES 
('special_character_percent_not_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_underscore_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_underscorenot_present')") + .executeUpdate() } def tablePreparation(connection: Connection): Unit diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index fc93f5cba4c03..de8fcf1a4a787 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.DatabaseOnDocker +import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -60,19 +60,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD "scan with aggregate push-down: REGR_SXY without DISTINCT") override val catalogName: String = "mssql" - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } + override val db = new MsSQLServerDatabaseOnDocker override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) @@ -86,6 +74,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD connection.prepareStatement( "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") .executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def notSupportsTableComment: Boolean = true @@ -93,11 +87,13 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -125,4 +121,20 @@ class MsSqlServerIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JD }, errorClass = "_LEGACY_ERROR_TEMP_2271") } + + test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { + val df1 = sql("SELECT name FROM " + + s"$catalogName.employee WHERE ((name LIKE 'am%') = (name LIKE '%y'))") + assert(df1.collect().length == 4) + + val df2 = sql("SELECT name FROM " + + s"$catalogName.employee " + + "WHERE ((name NOT LIKE 'am%') = (name NOT LIKE '%y'))") + assert(df2.collect().length == 4) + + val df3 = sql("SELECT name FROM " + + s"$catalogName.employee " + + "WHERE (dept > 1 AND ((name LIKE 'am%') = (name LIKE '%y')))") + assert(df3.collect().length == 3) + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index b0a2d37e465ac..de0ae5d59716b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import scala.collection.JavaConverters._ -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, MsSQLServerDatabaseOnDocker} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest @@ -35,20 +35,7 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } - + override val db = new MsSQLServerDatabaseOnDocker val map = new CaseInsensitiveStringMap( Map("url" -> db.getJdbcUrl(dockerIp, externalPort), "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver").asJava) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 5e340f135c85d..faf9f14b260d4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -68,8 +68,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/" + - s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false&disableMariaDbDriver" } override def sparkConf: SparkConf = super.sparkConf @@ -88,16 +88,24 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest connection.prepareStatement( "CREATE TABLE employee (dept INT, name VARCHAR(32), salary 
DECIMAL(20, 2)," + " bonus DOUBLE)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col LONGTEXT + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -145,7 +153,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('ENGINE'='InnoDB', 'DEFAULT CHARACTER SET'='utf8')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } @@ -164,3 +173,32 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest } } } + +/** + * To run this test suite for a specific version (e.g., mysql:8.3.0): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.3.0 + * ./build/sbt -Pdocker-integration-tests + * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" + * }}} + */ +@DockerTest +class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { + override def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", true) + .build() + + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false" + } +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index d58146fecdf42..8b889f8509f56 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -45,8 +45,8 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/" + - s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false&disableMariaDbDriver" } val map = new CaseInsensitiveStringMap( diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 5124199328ce2..002091b6a0d80 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -22,8 +22,9 @@ import java.util.Locale import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkRuntimeException} import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ @@ -86,6 +87,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes s"jdbc:oracle:thin:system/$oracle_password@//$ip:$port/xe" } + override def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType]) + .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, "varchar(255)") + .build() + override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) @@ -99,16 +106,24 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes connection.prepareStatement( "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + " bonus BINARY_DOUBLE)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", DecimalType(10, 0), true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", DecimalType(10, 0), true, super.defaultMetadata(DecimalType(10, 0))) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE LONG") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", DecimalType(19, 0), true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", DecimalType(19, 0), true, super.defaultMetadata(DecimalType(19, 0))) assert(t.schema === expectedSchema) // Update column type from LONG to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -129,12 +144,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) - test("SPARK-43049: Use CLOB instead of VARCHAR(255) for StringType for Oracle JDBC") { + test("SPARK-46478: Revert SPARK-43049 to use varchar(255) for string") { val tableName = catalogName + ".t1" withTable(tableName) { sql(s"CREATE TABLE $tableName(c1 string)") - sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") - assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 256) + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") + }, + 
errorClass = "EXCEED_LIMIT_LENGTH", + parameters = Map("limit" -> "255") + ) } } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 85e85f8bf3803..7fef3ccd6b3f6 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -59,16 +59,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT connection.prepareStatement( "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + " bonus double precision)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -91,7 +99,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('TABLESPACE'='pg_default')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index cf7266e67e325..b725fc8967514 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index b5f5b0e5f20bd..a0f337912c859 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -49,18 +49,21 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def notSupportsTableComment: Boolean = false - val defaultMetadata = new MetadataBuilder().putLong("scale", 0).build() + def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType]) + .build() def testUpdateColumnNullability(tbl: String): Unit = { sql(s"CREATE TABLE $catalogName.alt_table (ID STRING NOT NULL)") var t = spark.table(s"$catalogName.alt_table") // nullable is true in the expectedSchema because Spark always sets nullable to true // regardless of the JDBC metadata https://github.com/apache/spark/pull/18445 - var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ALTER COLUMN ID DROP NOT NULL") t = spark.table(s"$catalogName.alt_table") - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update nullability of not existing column val msg = intercept[AnalysisException] { @@ -72,8 +75,9 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def testRenameColumn(tbl: String): Unit = { sql(s"ALTER TABLE $tbl RENAME COLUMN ID TO RENAMED") val t = spark.table(s"$tbl") - val expectedSchema = new StructType().add("RENAMED", StringType, true, defaultMetadata) - .add("ID1", StringType, true, defaultMetadata).add("ID2", StringType, true, defaultMetadata) + val expectedSchema = new StructType().add("RENAMED", StringType, true, defaultMetadata()) + .add("ID1", StringType, true, defaultMetadata()) + .add("ID2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) } @@ -83,16 +87,19 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with 
DockerIntegrationFu withTable(s"$catalogName.alt_table") { sql(s"CREATE TABLE $catalogName.alt_table (ID STRING)") var t = spark.table(s"$catalogName.alt_table") - var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C1 STRING, C2 STRING)") t = spark.table(s"$catalogName.alt_table") - expectedSchema = expectedSchema.add("C1", StringType, true, defaultMetadata) - .add("C2", StringType, true, defaultMetadata) + expectedSchema = expectedSchema + .add("C1", StringType, true, defaultMetadata()) + .add("C2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 STRING)") t = spark.table(s"$catalogName.alt_table") - expectedSchema = expectedSchema.add("C3", StringType, true, defaultMetadata) + expectedSchema = expectedSchema + .add("C3", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Add already existing column checkError( @@ -125,7 +132,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu sql(s"ALTER TABLE $catalogName.alt_table DROP COLUMN C1") sql(s"ALTER TABLE $catalogName.alt_table DROP COLUMN c3") val t = spark.table(s"$catalogName.alt_table") - val expectedSchema = new StructType().add("C2", StringType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("C2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Drop not existing column val msg = intercept[AnalysisException] { @@ -350,6 +358,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(scan.schema.names.sameElements(Seq(col))) } + test("SPARK-48172: Test CONTAINS") { + val df1 = spark.sql( + s""" + |SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin) + df1.explain("formatted") + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'character') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test ENDSWITH") { + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'present') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test STARTSWITH") { + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test LIKE") { + // this one should map to contains + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%character%' + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + + // map to startsWith + // this one should map to contains + val df5 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin) + val rows5 = df5.collect() + assert(rows5.length === 1) + assert(rows5(0).getString(0) === "special_character_quote'_present") + + val df6 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "special_character_percent%_present") + + val df7 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "special_character_underscore_present") + + val df8 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character%' + |ORDER BY pattern_testing_col""".stripMargin) + val rows8 = df8.collect() + assert(rows8.length === 6) + assert(rows8(0).getString(0) === "special_character_percent%_present") + assert(rows8(1).getString(0) === "special_character_percent_not_present") + assert(rows8(2).getString(0) === "special_character_quote'_present") + assert(rows8(3).getString(0) === "special_character_quote_not_present") + assert(rows8(4).getString(0) === "special_character_underscore_present") + assert(rows8(5).getString(0) === "special_character_underscorenot_present") + // map to endsWith + // this one should map to contains + val df9 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "special_character_quote'_present") + + val df10 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin) + val rows10 = df10.collect() + assert(rows10.length === 1) + assert(rows10(0).getString(0) === "special_character_percent%_present") + + val df11 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin) + val rows11 = df11.collect() + assert(rows11.length === 1) + assert(rows11(0).getString(0) === "special_character_underscore_present") + + val df12 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin) + val rows12 = df12.collect() + assert(rows12.length === 6) + assert(rows12(0).getString(0) === "special_character_percent%_present") + assert(rows12(1).getString(0) === "special_character_percent_not_present") + assert(rows12(2).getString(0) === "special_character_quote'_present") + assert(rows12(3).getString(0) === "special_character_quote_not_present") + assert(rows12(4).getString(0) === "special_character_underscore_present") + assert(rows12(5).getString(0) === "special_character_underscorenot_present") + } + test("SPARK-37038: Test TABLESAMPLE") { if (supportsTableSample) { withTable(s"$catalogName.new_table") { diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index e7d86b6fd7560..ae11f0eac307d 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index 8f41efc15cacb..533a45e18f662 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index b22b937cd821e..07ca1c2b2f3c7 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index 
825868ebd9581..176d92da63801 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index 312b9c460777a..a6ef06142f5cb 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index 134f9c22d7436..4282e1f035716 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 7b8b45704a5ef..2af6002b5c7db 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index 5c4a5ff068968..d2417674837be 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -22,12 +22,12 @@ import scala.util.control.NonFatal import com.google.protobuf.DynamicMessage import com.google.protobuf.TypeRegistry -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters} -import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} private[sql] case class ProtobufDataToCatalyst( child: Expression, @@ -39,16 +39,8 @@ private[sql] case class ProtobufDataToCatalyst( override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) - override lazy val dataType: DataType = { - val dt = SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType - parseMode match { - // With PermissiveMode, the output Catalyst row might contain columns of null values for - // corrupt records, even if some of the columns are not nullable in the user-provided schema. - // Therefore we force the schema to be all nullable here. 
- case PermissiveMode => dt.asNullable - case _ => dt - } - } + override lazy val dataType: DataType = + SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType override def nullable: Boolean = true @@ -87,22 +79,9 @@ private[sql] case class ProtobufDataToCatalyst( mode } - @transient private lazy val nullResultRow: Any = dataType match { - case st: StructType => - val resultRow = new SpecificInternalRow(st.map(_.dataType)) - for (i <- 0 until st.length) { - resultRow.setNullAt(i) - } - resultRow - - case _ => - null - } - private def handleException(e: Throwable): Any = { parseMode match { - case PermissiveMode => - nullResultRow + case PermissiveMode => null case FailFastMode => throw QueryExecutionErrors.malformedProtobufMessageDetectedInMessageParsingError(e) case _ => diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index d3e63a11a66bf..62d0efd7459b2 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -79,20 +79,9 @@ class ProtobufCatalystDataConversionSuite .eval() } - val expected = { - val expectedSchema = ProtobufUtils.buildDescriptor(descBytes, badSchema) - SchemaConverters.toSqlType(expectedSchema).dataType match { - case st: StructType => - Row.fromSeq((0 until st.length).map { _ => - null - }) - case _ => null - } - } - checkEvaluation( ProtobufDataToCatalyst(binary, badSchema, Some(descBytes), Map("mode" -> "PERMISSIVE")), - expected) + expected = null) } protected def prepareExpectedResult(expected: Any): Any = expected match { @@ -137,7 +126,8 @@ class ProtobufCatalystDataConversionSuite while ( data != null && (data.get(0) == defaultValue || - (dt == BinaryType && + (dt.fields(0).dataType == BinaryType && + data.get(0) != null && data.get(0).asInstanceOf[Array[Byte]].isEmpty))) data = generator().asInstanceOf[Row] diff --git a/connector/spark-ganglia-lgpl/pom.xml b/connector/spark-ganglia-lgpl/pom.xml index a5870edfc7c81..a46c9bbfec2cf 100644 --- a/connector/spark-ganglia-lgpl/pom.xml +++ b/connector/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index c40f9905245cb..d1b0e82c7c0d5 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml @@ -243,6 +243,10 @@ org.scala-lang.modules scala-xml_${scala.binary.version} + + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + org.scala-lang scala-library diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 91910b99ac999..2a580e341dc33 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -14,6 +14,7 @@ package org.apache.spark.io; import org.apache.spark.storage.StorageUtils; +import org.apache.spark.unsafe.Platform; import java.io.File; import java.io.IOException; @@ -39,7 +40,7 @@ public final class NioBufferedFileInputStream extends InputStream { private final FileChannel fileChannel; public NioBufferedFileInputStream(File file, int 
bufferSizeInBytes) throws IOException { - byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes); fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); byteBuffer.flip(); } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 83352611770fd..08c080f5a5a1d 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -18,6 +18,7 @@ package org.apache.spark.memory; import javax.annotation.concurrent.GuardedBy; +import java.io.InterruptedIOException; import java.io.IOException; import java.nio.channels.ClosedByInterruptException; import java.util.Arrays; @@ -242,7 +243,7 @@ private long trySpillAndAcquire( cList.remove(idx); return 0; } - } catch (ClosedByInterruptException e) { + } catch (ClosedByInterruptException | InterruptedIOException e) { // This called by user to kill a task (e.g: speculative task). logger.error("error while calling spill() on " + consumerToSpill, e); throw new RuntimeException(e.getMessage()); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index a82f691d085d4..b097089282ce3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -150,11 +150,21 @@ public long[] getChecksums() { * Sorts the in-memory records and writes the sorted records to an on-disk file. * This method does not free the sort data structures. * - * @param isLastFile if true, this indicates that we're writing the final output file and that the - * bytes written should be counted towards shuffle spill metrics rather than - * shuffle write metrics. + * @param isFinalFile if true, this indicates that we're writing the final output file and that + * the bytes written should be counted towards shuffle write metrics rather + * than shuffle spill metrics. */ - private void writeSortedFile(boolean isLastFile) { + private void writeSortedFile(boolean isFinalFile) { + // Only emit the log if this is an actual spilling. + if (!isFinalFile) { + logger.info( + "Task {} on Thread {} spilling sort data of {} to disk ({} {} so far)", + taskContext.taskAttemptId(), + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() != 1 ? " times" : " time"); + } // This call performs the actual sort. final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = @@ -167,13 +177,14 @@ private void writeSortedFile(boolean isLastFile) { final ShuffleWriteMetricsReporter writeMetricsToUse; - if (isLastFile) { + if (isFinalFile) { // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. writeMetricsToUse = writeMetrics; } else { // We're spilling, so bytes written should be counted towards spill rather than write. // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count // them towards shuffle bytes written. + // The actual shuffle bytes written will be counted when we merge the spill files. writeMetricsToUse = new ShuffleWriteMetrics(); } @@ -246,7 +257,7 @@ private void writeSortedFile(boolean isLastFile) { spills.add(spillInfo); } - if (!isLastFile) { // i.e. 
this is a spill file + if (!isFinalFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter // relies on its `recordWritten()` method being called in order to trigger periodic updates to @@ -281,12 +292,6 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { return 0L; } - logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", - Thread.currentThread().getId(), - Utils.bytesToString(getMemoryUsage()), - spills.size(), - spills.size() > 1 ? " times" : " time"); - writeSortedFile(false); final long spillSize = freeMemory(); inMemSorter.reset(); @@ -440,8 +445,9 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p */ public SpillInfo[] closeAndGetSpills() throws IOException { if (inMemSorter != null) { - // Do not count the final file towards the spill count. - writeSortedFile(true); + // Here we are spilling the remaining data in the buffer. If there is no spill before, this + // final spill file will be the final shuffle output file. + writeSortedFile(/* isFinalFile = */spills.isEmpty()); freeMemory(); inMemSorter.free(); inMemSorter = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9c54184105951..d5b4eb138b1a6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -327,12 +327,6 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep logger.debug("Using slow merge"); mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. 
- writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); } catch (Exception e) { try { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index eb4d9d9abc8e3..38f0a60f8b0dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort.io; +import java.util.Collections; import java.util.Map; import java.util.Optional; @@ -56,7 +57,10 @@ public void initializeExecutor(String appId, String execId, Map if (blockManager == null) { throw new IllegalStateException("No blockManager available from the SparkEnv."); } - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + blockResolver = + new IndexShuffleBlockResolver( + sparkConf, blockManager, Collections.emptyMap() /* Shouldn't be accessed */ + ); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index db79efd008530..cf29835b2ce89 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -28,6 +28,8 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.*; @@ -36,6 +38,7 @@ * of the file format). 
*/ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { + private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); public static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb private InputStream in; @@ -82,6 +85,15 @@ public UnsafeSorterSpillReader( Closeables.close(bs, /* swallowIOException = */ true); throw e; } + if (taskContext != null) { + taskContext.addTaskCompletionListener(context -> { + try { + close(); + } catch (IOException e) { + logger.info("error while closing UnsafeSorterSpillReader", e); + } + }); + } } @Override diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index b334bceb5a039..68dc8ba316dbf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -192,7 +192,12 @@ $(document).ready(function() { }, {name: startedColumnName, data: 'startTime' }, {name: completedColumnName, data: 'endTime' }, - {name: durationColumnName, type: "title-numeric", data: 'duration' }, + { + name: durationColumnName, + type: "title-numeric", + data: 'duration', + render: (id, type, row) => `${row.duration}` + }, {name: 'user', data: 'sparkUser' }, {name: 'lastUpdated', data: 'lastUpdated' }, { diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 50bf959d3aa96..c7513c8268b26 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -650,6 +650,7 @@ $(document).ready(function () { } executorSummaryTableSelector.column(13).visible(dataToShow.showBytesSpilledData); executorSummaryTableSelector.column(14).visible(dataToShow.showBytesSpilledData); + reselectCheckboxesBasedOnTaskTableState(); }); // Prepare data for speculation metrics diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java b/core/src/main/scala-2.13/org/apache/spark/util/ArrayImplicits.scala similarity index 67% rename from sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java rename to core/src/main/scala-2.13/org/apache/spark/util/ArrayImplicits.scala index b84a3122cf84c..38c2a415af3db 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaBeanWithGenerics.java +++ b/core/src/main/scala-2.13/org/apache/spark/util/ArrayImplicits.scala @@ -15,27 +15,21 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst; +package org.apache.spark.util -class JavaBeanWithGenerics { - private A attribute; +import scala.collection.immutable - private T value; - - public A getAttribute() { - return attribute; - } - - public void setAttribute(A attribute) { - this.attribute = attribute; - } +/** + * Implicit methods related to Scala Array. + */ +private[spark] object ArrayImplicits { - public T getValue() { - return value; - } + implicit class SparkArrayOps[T](xs: Array[T]) { - public void setValue(T value) { - this.value = value; - } + /** + * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. 
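The UnsafeSorterSpillReader change above registers a task-completion listener so the spill stream is closed even when the caller abandons the iterator early. A minimal Scala sketch of that pattern, using the real TaskContext.get()/addTaskCompletionListener APIs but a stand-in stream (the byte array and the silent catch are illustrative only):

  import java.io.{ByteArrayInputStream, IOException, InputStream}
  import org.apache.spark.TaskContext

  // Stand-in for the spill file's input stream in the hunk above.
  val in: InputStream = new ByteArrayInputStream(Array[Byte](1, 2, 3))

  // Close the resource when the task completes, whether or not it was fully consumed.
  Option(TaskContext.get()).foreach { ctx =>
    ctx.addTaskCompletionListener[Unit] { _ =>
      try in.close()
      catch { case _: IOException => () } // the real code logs the failure instead of ignoring it
    }
  }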
+ */ + def toImmutableArraySeq: immutable.ArraySeq[T] = + immutable.ArraySeq.unsafeWrapArray(xs) + } } - diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index ecc0c891ea161..94ba3fe64a859 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -193,6 +193,8 @@ class BarrierTaskContext private[spark] ( override def isCompleted(): Boolean = taskContext.isCompleted() + override def isFailed(): Boolean = taskContext.isFailed() + override def isInterrupted(): Boolean = taskContext.isInterrupted() override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3495536a3508f..9a7a3b0c0e75e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -42,7 +42,6 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} import org.apache.spark.util._ -import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** @@ -151,17 +150,22 @@ private class ShuffleStatus( /** * Mapping from a mapId to the mapIndex, this is required to reduce the searching overhead within * the function updateMapOutput(mapId, bmAddress). + * + * Exposed for testing. */ - private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]() + private[spark] val mapIdToMapIndex = new HashMap[Long, Int]() /** * Register a map output. If there is already a registered location for the map output then it * will be replaced by the new location. 
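The ArrayImplicits helper completed just above wraps an Array[T] as an immutable.ArraySeq[T] with no copy. A small usage sketch for the Scala 2.13 build; the package and object names here are hypothetical, and the enclosing package matters because the helper is private[spark]:

  package org.apache.spark.example // must live under org.apache.spark to see the private[spark] helper

  import scala.collection.immutable

  import org.apache.spark.util.ArrayImplicits._

  object ToImmutableExample {
    val parts: Array[Int] = Array(1, 2, 3)
    // No copy is made: the ArraySeq shares the backing array, so the array must not be mutated later.
    val seq: immutable.ArraySeq[Int] = parts.toImmutableArraySeq
  }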
*/ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { - if (mapStatuses(mapIndex) == null) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus == null) { _numAvailableMapOutputs += 1 invalidateSerializedMapOutputStatusCache() + } else { + mapIdToMapIndex.remove(currentMapStatus.mapId) } mapStatuses(mapIndex) = status mapIdToMapIndex(status.mapId) = mapIndex @@ -190,8 +194,8 @@ private class ShuffleStatus( mapStatus.updateLocation(bmAddress) invalidateSerializedMapOutputStatusCache() case None => - if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) { - val index = mapIndex.get + val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId) + if (index >= 0 && mapStatuses(index) == null) { val mapStatus = mapStatusesDeleted(index) mapStatus.updateLocation(bmAddress) mapStatuses(index) = mapStatus @@ -216,9 +220,11 @@ private class ShuffleStatus( */ def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") - if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus != null && currentMapStatus.location == bmAddress) { _numAvailableMapOutputs -= 1 - mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapIdToMapIndex.remove(currentMapStatus.mapId) + mapStatusesDeleted(mapIndex) = currentMapStatus mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } @@ -284,9 +290,11 @@ private class ShuffleStatus( */ def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { for (mapIndex <- mapStatuses.indices) { - if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { + val currentMapStatus = mapStatuses(mapIndex) + if (currentMapStatus != null && f(currentMapStatus.location)) { _numAvailableMapOutputs -= 1 - mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapIdToMapIndex.remove(currentMapStatus.mapId) + mapStatusesDeleted(mapIndex) = currentMapStatus mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 813a14acd19e4..f49e9e357c84d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -638,7 +638,9 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.blacklist.killBlacklistedExecutors", "3.1.0", "Please use spark.excludeOnFailure.killExcludedExecutors"), DeprecatedConfig("spark.yarn.blacklist.executor.launch.blacklisting.enabled", "3.1.0", - "Please use spark.yarn.executor.launch.excludeOnFailure.enabled") + "Please use spark.yarn.executor.launch.excludeOnFailure.enabled"), + DeprecatedConfig("spark.network.remoteReadNioBufferConversion", "3.5.2", + "Please open a JIRA ticket to report it if you need to use this configuration.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4aea442bc3ce1..115f0663ef2b7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -281,12 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - 
SparkEnv.createDriverEnv( - conf, - isLocal, - listenerBus, - SparkContext.numDriverCores(master, conf), - this) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf)) } private[spark] def env: SparkEnv = _env diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 272a0a6332bbe..edad91a0c6f0d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -169,7 +169,6 @@ object SparkEnv extends Logging { isLocal: Boolean, listenerBus: LiveListenerBus, numCores: Int, - sparkContext: SparkContext, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains(DRIVER_HOST_ADDRESS), s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") @@ -192,7 +191,6 @@ object SparkEnv extends Logging { numCores, ioEncryptionKey, listenerBus = listenerBus, - Option(sparkContext), mockOutputCommitCoordinator = mockOutputCommitCoordinator ) } @@ -237,7 +235,6 @@ object SparkEnv extends Logging { /** * Helper method to create a SparkEnv for a driver or an executor. */ - // scalastyle:off argcount private def create( conf: SparkConf, executorId: String, @@ -248,9 +245,7 @@ object SparkEnv extends Logging { numUsableCores: Int, ioEncryptionKey: Option[Array[Byte]], listenerBus: LiveListenerBus = null, - sc: Option[SparkContext] = None, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - // scalastyle:on argcount val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER @@ -396,12 +391,7 @@ object SparkEnv extends Logging { } val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - if (isDriver) { - new OutputCommitCoordinator(conf, isDriver, sc) - } else { - new OutputCommitCoordinator(conf, isDriver) - } - + new OutputCommitCoordinator(conf, isDriver) } val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 450c00928c9e6..af7aa4979dc1c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -94,6 +94,11 @@ abstract class TaskContext extends Serializable { */ def isCompleted(): Boolean + /** + * Returns true if the task has failed. + */ + def isFailed(): Boolean + /** * Returns true if the task has been killed. 
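TaskContext gains an isFailed() method in the hunks above (wired through BarrierTaskContext and TaskContextImpl as well). A hedged sketch of how caller code might use it from a completion listener to distinguish a failed attempt from a successful one; withFailureCleanup is a made-up helper name:

  import org.apache.spark.TaskContext

  def withFailureCleanup[T](cleanup: () => Unit)(body: => T): T = {
    val ctx = TaskContext.get()
    if (ctx != null) {
      ctx.addTaskCompletionListener[Unit] { c =>
        if (c.isFailed()) cleanup() // e.g. remove partial output written by this attempt
      }
    }
    body
  }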
*/ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 526627c28607d..46273a1b6d687 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -275,6 +275,8 @@ private[spark] class TaskContextImpl( @GuardedBy("this") override def isCompleted(): Boolean = synchronized(completed) + override def isFailed(): Boolean = synchronized(failureCauseOpt.isDefined) + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index d30e9c5e2ce61..692ae45a12f4d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -33,6 +33,7 @@ import org.apache.spark.errors.SparkCoreErrors * A class to test Pickle serialization on the Scala side, that will be deserialized * in Python */ +@deprecated("This class will be move to `test`.", "3.5.2") case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) @@ -104,6 +105,7 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra * This object contains method to generate SequenceFile test data and write it to a * given directory (probably a temp directory) */ +@deprecated("This class will be move to `test`.", "3.5.2") object WriteInputFormatTestDataGenerator { def main(args: Array[String]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 60253ed5fda1f..0f0d8b6c07c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -41,7 +41,7 @@ import org.apache.ivy.Ivy import org.apache.ivy.core.LogOptions import org.apache.ivy.core.module.descriptor._ import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} -import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.report.{DownloadStatus, ResolveReport} import org.apache.ivy.core.resolve.ResolveOptions import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings @@ -683,7 +683,7 @@ private[spark] class SparkSubmit extends Logging { confKey = EXECUTOR_CORES.key), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN | KUBERNETES, ALL_DEPLOY_MODES, confKey = EXECUTOR_MEMORY.key), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = CORES_MAX.key), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = FILES.key), @@ -1226,7 +1226,7 @@ private[spark] object SparkSubmitUtils extends Logging { s"be whitespace. The artifactId provided is: ${splits(1)}") require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + s"be whitespace. 
The version provided is: ${splits(2)}") - new MavenCoordinate(splits(0), splits(1), splits(2)) + MavenCoordinate(splits(0), splits(1), splits(2)) } } @@ -1241,21 +1241,27 @@ private[spark] object SparkSubmitUtils extends Logging { } /** - * Extracts maven coordinates from a comma-delimited string + * Create a ChainResolver used by Ivy to search for and resolve dependencies. + * * @param defaultIvyUserDir The default user path for Ivy + * @param useLocalM2AsCache Whether to use the local maven repo as a cache * @return A ChainResolver used by Ivy to search for and resolve dependencies. */ - def createRepoResolvers(defaultIvyUserDir: File): ChainResolver = { + def createRepoResolvers( + defaultIvyUserDir: File, + useLocalM2AsCache: Boolean = true): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver cr.setName("spark-list") - val localM2 = new IBiblioResolver - localM2.setM2compatible(true) - localM2.setRoot(m2Path.toURI.toString) - localM2.setUsepoms(true) - localM2.setName("local-m2-cache") - cr.add(localM2) + if (useLocalM2AsCache) { + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + localM2.setRoot(m2Path.toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + } val localIvy = new FileSystemResolver val localIvyRoot = new File(defaultIvyUserDir, "local") @@ -1351,18 +1357,23 @@ private[spark] object SparkSubmitUtils extends Logging { /** * Build Ivy Settings using options with default resolvers + * * @param remoteRepos Comma-delimited string of remote repositories other than maven central * @param ivyPath The path to the local ivy repository + * @param useLocalM2AsCache Whether or not use `local-m2 repo` as cache * @return An IvySettings object */ - def buildIvySettings(remoteRepos: Option[String], ivyPath: Option[String]): IvySettings = { + def buildIvySettings( + remoteRepos: Option[String], + ivyPath: Option[String], + useLocalM2AsCache: Boolean = true): IvySettings = { val ivySettings: IvySettings = new IvySettings processIvyPathArg(ivySettings, ivyPath) // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers - val repoResolver = createRepoResolvers(ivySettings.getDefaultIvyUserDir) + val repoResolver = createRepoResolvers(ivySettings.getDefaultIvyUserDir, useLocalM2AsCache) ivySettings.addResolver(repoResolver) ivySettings.setDefaultResolver(repoResolver.getName) processRemoteRepoArg(ivySettings, remoteRepos) @@ -1459,7 +1470,7 @@ private[spark] object SparkSubmitUtils extends Logging { */ private def clearIvyResolutionFiles( mdId: ModuleRevisionId, - ivySettings: IvySettings, + defaultCacheFile: File, ivyConfName: String): Unit = { val currentResolutionFiles = Seq( s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", @@ -1467,14 +1478,40 @@ private[spark] object SparkSubmitUtils extends Logging { s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" ) currentResolutionFiles.foreach { filename => - new File(ivySettings.getDefaultCache, filename).delete() + new File(defaultCacheFile, filename).delete() + } + } + + /** + * Clear invalid cache files in ivy. The cache file is usually at + * ~/.ivy2/cache/${groupId}/${artifactId}/ivy-${version}.xml, + * ~/.ivy2/cache/${groupId}/${artifactId}/ivy-${version}.xml.original, and + * ~/.ivy2/cache/${groupId}/${artifactId}/ivydata-${version}.properties. 
+ * Because when using `local-m2` repo as a cache, some invalid files were created. + * If not deleted here, an error prompt similar to `unknown resolver local-m2-cache` + * will be generated, making some confusion for users. + */ + private def clearInvalidIvyCacheFiles( + mdId: ModuleRevisionId, + defaultCacheFile: File): Unit = { + val cacheFiles = Seq( + s"${mdId.getOrganisation}${File.separator}${mdId.getName}${File.separator}" + + s"ivy-${mdId.getRevision}.xml", + s"${mdId.getOrganisation}${File.separator}${mdId.getName}${File.separator}" + + s"ivy-${mdId.getRevision}.xml.original", + s"${mdId.getOrganisation}${File.separator}${mdId.getName}${File.separator}" + + s"ivydata-${mdId.getRevision}.properties") + cacheFiles.foreach { filename => + new File(defaultCacheFile, filename).delete() } } /** * Resolves any dependencies that were supplied through maven coordinates + * * @param coordinates Comma-delimited string of maven coordinates * @param ivySettings An IvySettings containing resolvers to use + * @param noCacheIvySettings An no-cache(local-m2-cache) IvySettings containing resolvers to use * @param transitive Whether resolving transitive dependencies, default is true * @param exclusions Exclusions to apply when resolving transitive dependencies * @return Seq of path to the jars of the given maven artifacts including their @@ -1483,6 +1520,7 @@ private[spark] object SparkSubmitUtils extends Logging { def resolveMavenCoordinates( coordinates: String, ivySettings: IvySettings, + noCacheIvySettings: Option[IvySettings] = None, transitive: Boolean, exclusions: Seq[String] = Nil, isTest: Boolean = false): Seq[String] = { @@ -1511,6 +1549,8 @@ private[spark] object SparkSubmitUtils extends Logging { // scalastyle:on println val ivy = Ivy.newInstance(ivySettings) + ivy.pushContext() + // Set resolve options to download transitive dependencies as well val resolveOptions = new ResolveOptions resolveOptions.setTransitive(transitive) @@ -1523,6 +1563,11 @@ private[spark] object SparkSubmitUtils extends Logging { } else { resolveOptions.setDownload(true) } + // retrieve all resolved dependencies + retrieveOptions.setDestArtifactPattern( + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision](-[classifier]).[ext]") + retrieveOptions.setConfs(Array(ivyConfName)) // Add exclusion rules for Spark and Scala Library addExclusionRules(ivySettings, ivyConfName, md) @@ -1534,17 +1579,44 @@ private[spark] object SparkSubmitUtils extends Logging { // resolve dependencies val rr: ResolveReport = ivy.resolve(md, resolveOptions) if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // SPARK-46302: When there are some corrupted jars in the local maven repo, + // we try to continue without the cache + val failedReports = rr.getArtifactsReports(DownloadStatus.FAILED, true) + if (failedReports.nonEmpty && noCacheIvySettings.isDefined) { + val failedArtifacts = failedReports.map(r => r.getArtifact) + logInfo(s"Download failed: ${failedArtifacts.mkString("[", ", ", "]")}, " + + s"attempt to retry while skipping local-m2-cache.") + failedArtifacts.foreach(artifact => { + clearInvalidIvyCacheFiles(artifact.getModuleRevisionId, ivySettings.getDefaultCache) + }) + ivy.popContext() + + val noCacheIvy = Ivy.newInstance(noCacheIvySettings.get) + noCacheIvy.pushContext() + + val noCacheRr = noCacheIvy.resolve(md, resolveOptions) + if (noCacheRr.hasError) { + throw new RuntimeException(noCacheRr.getAllProblemMessages.toString) + } + 
noCacheIvy.retrieve(noCacheRr.getModuleDescriptor.getModuleRevisionId, retrieveOptions) + val dependencyPaths = resolveDependencyPaths( + noCacheRr.getArtifacts.toArray, packagesDirectory) + noCacheIvy.popContext() + + dependencyPaths + } else { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + } else { + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, retrieveOptions) + val dependencyPaths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + ivy.popContext() + + dependencyPaths } - // retrieve all resolved dependencies - retrieveOptions.setDestArtifactPattern(packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision](-[classifier]).[ext]") - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } finally { System.setOut(sysOut) - clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings, ivyConfName) + clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index a3fe5153bee9f..93dd25db0937b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -567,7 +567,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. | - | Spark standalone, Mesos and Kubernetes only: + | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | | Spark standalone, YARN and Kubernetes only: diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileReaders.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileReaders.scala index b21c67a2823af..8c3dda4727784 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileReaders.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileReaders.scala @@ -119,7 +119,9 @@ object EventLogFileReader extends Logging { if (isSingleEventLog(status)) { Some(new SingleFileEventLogFileReader(fs, status.getPath, Option(status))) } else if (isRollingEventLogs(status)) { - if (fs.listStatus(status.getPath).exists(RollingEventLogFilesWriter.isEventLogFile)) { + val files = fs.listStatus(status.getPath) + if (files.exists(RollingEventLogFilesWriter.isEventLogFile) && + files.exists(RollingEventLogFilesWriter.isAppStatusFile)) { Some(new RollingEventLogFilesFileReader(fs, status.getPath)) } else { logDebug(s"Rolling event log directory have no event log file at ${status.getPath}") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 49b479f3124e9..387bc7d9e45b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -925,11 +925,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * UI lifecycle. 
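The SparkSubmitUtils changes above add a useLocalM2AsCache flag and an optional no-cache IvySettings that is used to retry a resolution when an artifact cached from the local Maven repo turns out to be corrupt (SPARK-46302). A sketch of how a caller inside Spark could wire the two together; SparkSubmitUtils is private[spark], and the Kafka coordinate is just an arbitrary example:

  import org.apache.spark.deploy.SparkSubmitUtils

  val repos = sys.props.get("spark.jars.repositories")
  val ivyPath = sys.props.get("spark.jars.ivy")

  // Normal settings: the local ~/.m2 repository participates as a cache resolver.
  val ivySettings = SparkSubmitUtils.buildIvySettings(repos, ivyPath)
  // Fallback settings without the local-m2-cache resolver, used only to retry failed downloads.
  val noCacheSettings = SparkSubmitUtils.buildIvySettings(repos, ivyPath, useLocalM2AsCache = false)

  val jarPaths = SparkSubmitUtils.resolveMavenCoordinates(
    "org.apache.kafka:kafka-clients:3.4.1", // arbitrary example coordinate
    ivySettings,
    noCacheIvySettings = Some(noCacheSettings),
    transitive = true)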
*/ private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { - synchronized { - activeUIs.get((appId, attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() - } + val uiOption = synchronized { + activeUIs.get((appId, attemptId)) + } + uiOption.foreach { ui => + ui.invalidate() + ui.ui.store.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 9e10a0bbf3964..31d541368ab45 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -21,6 +21,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.apache.commons.lang3.StringUtils + import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.StandaloneResourceUtils.{formatResourceRequirements, formatResourcesAddresses} @@ -93,10 +95,14 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
State: {app.state}
{
  if (!app.isFinished) {
-     Application Detail UI
+     if (StringUtils.isBlank(app.desc.appUiUrl)) {
+       Application UI: Disabled
+     } else {
+       Application Detail UI
  • + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index a71eb33a2fe1d..078ed102f0bbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.apache.commons.lang3.StringUtils import org.json4s.JValue import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, MasterStateResponse, RequestKillDriver, RequestMasterState} @@ -289,7 +290,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { { - if (app.isFinished) { + if (app.isFinished || StringUtils.isBlank(app.desc.appUiUrl)) { app.desc.name } else { {formatResourcesAddresses(driver.resources)} {driver.desc.command.arguments(2)} {if (showDuration) { - {UIUtils.formatDuration(System.currentTimeMillis() - driver.startTime)} + + {UIUtils.formatDuration(System.currentTimeMillis() - driver.startTime)} + }} } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index af94bd6d9e0f2..53e5c5ac2a8f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,6 +23,7 @@ import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.Master import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.DECOMMISSION_ENABLED import org.apache.spark.internal.config.UI.MASTER_UI_DECOMMISSION_ALLOW_MODE import org.apache.spark.internal.config.UI.UI_KILL_ENABLED import org.apache.spark.ui.{SparkUI, WebUI} @@ -40,6 +41,7 @@ class MasterWebUI( val masterEndpointRef = master.self val killEnabled = master.conf.get(UI_KILL_ENABLED) + val decommissionDisabled = !master.conf.get(DECOMMISSION_ENABLED) val decommissionAllowMode = master.conf.get(MASTER_UI_DECOMMISSION_ALLOW_MODE) initialize() @@ -58,7 +60,7 @@ class MasterWebUI( override def doPost(req: HttpServletRequest, resp: HttpServletResponse): Unit = { val hostnames: Seq[String] = Option(req.getParameterValues("host")) .getOrElse(Array[String]()).toSeq - if (!isDecommissioningRequestAllowed(req)) { + if (decommissionDisabled || !isDecommissioningRequestAllowed(req)) { resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } else { val removedWorkers = masterEndpointRef.askSync[Integer]( diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 6ec281f5b4406..c3f931f356ea7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -146,6 +146,9 @@ private[deploy] class HadoopFSDelegationTokenProvider val tokenKind = token.getKind.toString val interval = newExpiration - getIssueDate(tokenKind, identifier) logInfo(s"Renewal interval is $interval for token $tokenKind") + // The token here is only used to obtain renewal intervals. 
We should cancel it in + // a timely manner to avoid causing additional pressure on the server. + token.cancel(hadoopConf) interval }.toOption } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 3171d3f16e8a0..e740b328dd7b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -212,8 +212,8 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { {formatResourcesAddresses(driver.resources)} - stdout - stderr + stdout + stderr {driver.finalException.getOrElse("")} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ab238626efe9b..537522326fc78 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,9 +20,9 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer import java.util.Locale -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} -import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal @@ -71,12 +71,19 @@ private[spark] class CoarseGrainedExecutorBackend( /** * Map each taskId to the information about the resource allocated to it, Please refer to * [[ResourceInformation]] for specifics. + * CHM is used to ensure thread-safety (https://issues.apache.org/jira/browse/SPARK-45227) * Exposed for testing only. */ - private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]] + private[executor] val taskResources = new ConcurrentHashMap[ + Long, Map[String, ResourceInformation] + ] private var decommissioned = false + // Track the last time in ns that at least one task is running. If no task is running and all + // shuffle/RDD data migration are done, the decommissioned executor should exit. 
+ private var lastTaskFinishTime = new AtomicLong(System.nanoTime()) + override def onStart(): Unit = { if (env.conf.get(DECOMMISSION_ENABLED)) { val signal = env.conf.get(EXECUTOR_DECOMMISSION_SIGNAL) @@ -184,7 +191,7 @@ private[spark] class CoarseGrainedExecutorBackend( } else { val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - taskResources(taskDesc.taskId) = taskDesc.resources + taskResources.put(taskDesc.taskId, taskDesc.resources) executor.launchTask(this, taskDesc) } @@ -261,11 +268,12 @@ private[spark] class CoarseGrainedExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit = { - val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation]) + val resources = taskResources.getOrDefault(taskId, Map.empty[String, ResourceInformation]) val cpus = executor.runningTasks.get(taskId).taskDescription.cpus val msg = StatusUpdate(executorId, taskId, state, data, cpus, resources) if (TaskState.isFinished(state)) { taskResources.remove(taskId) + lastTaskFinishTime.set(System.nanoTime()) } driver match { case Some(driverRef) => driverRef.send(msg) @@ -338,7 +346,6 @@ private[spark] class CoarseGrainedExecutorBackend( val shutdownThread = new Thread("wait-for-blocks-to-migrate") { override def run(): Unit = { - var lastTaskRunningTime = System.nanoTime() val sleep_time = 1000 // 1s // This config is internal and only used by unit tests to force an executor // to hang around for longer when decommissioned. @@ -355,7 +362,7 @@ private[spark] class CoarseGrainedExecutorBackend( val (migrationTime, allBlocksMigrated) = env.blockManager.lastMigrationInfo() // We can only trust allBlocksMigrated boolean value if there were no tasks running // since the start of computing it. - if (allBlocksMigrated && (migrationTime > lastTaskRunningTime)) { + if (allBlocksMigrated && (migrationTime > lastTaskFinishTime.get())) { logInfo("No running tasks, all blocks migrated, stopping.") exitExecutor(0, ExecutorLossMessage.decommissionFinished, notifyDriver = true) } else { @@ -367,12 +374,6 @@ private[spark] class CoarseGrainedExecutorBackend( } } else { logInfo(s"Blocked from shutdown by ${executor.numRunningTasks} running tasks") - // If there is a running task it could store blocks, so make sure we wait for a - // migration loop to complete after the last task is done. - // Note: this is only advanced if there is a running task, if there - // is no running task but the blocks are not done migrating this does not - // move forward. 
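The decommissioning change above replaces the shutdown thread's local lastTaskRunningTime with an AtomicLong that is only advanced when a task actually finishes, so the exit check compares block-migration time against the last completion rather than a timestamp that kept moving while tasks were merely running. A condensed, illustrative Scala sketch of the resulting condition (names mirror the diff; this is not the backend code itself):

  import java.util.concurrent.atomic.AtomicLong

  val lastTaskFinishTime = new AtomicLong(System.nanoTime())

  // Called from statusUpdate(...) whenever a task reaches a finished state.
  def onTaskFinished(): Unit = lastTaskFinishTime.set(System.nanoTime())

  // Checked by the "wait-for-blocks-to-migrate" loop: exit only once all blocks are migrated
  // and the migration snapshot is newer than the last task completion.
  def readyToExit(numRunningTasks: Int, migrationTimeNs: Long, allBlocksMigrated: Boolean): Boolean =
    numRunningTasks == 0 && allBlocksMigrated && migrationTimeNs > lastTaskFinishTime.get()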
- lastTaskRunningTime = System.nanoTime() } Thread.sleep(sleep_time) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9290b5b36a8f7..69a91839bbeb5 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -530,7 +530,7 @@ private[spark] class Executor( // Collect latest accumulator values to report back to the driver val accums: Seq[AccumulatorV2[_, _]] = Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val accUpdates = accums.map(acc => acc.toInfoUpdate) setTaskFinishedAndClearInterruptStatus() (accums, accUpdates) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 78b39b0cbda68..e88b70eb655c5 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} @@ -30,7 +30,6 @@ import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.util._ - /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. @@ -150,6 +149,11 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v.asJava) + private val (readLock, writeLock) = { + val lock = new ReentrantReadWriteLock() + (lock.readLock(), lock.writeLock()) + } + /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted * data, defined only in tasks with input. @@ -264,15 +268,46 @@ class TaskMetrics private[spark] () extends Serializable { /** * External accumulators registered with this task. */ - @transient private[spark] lazy val _externalAccums = new CopyOnWriteArrayList[AccumulatorV2[_, _]] + @transient private[spark] lazy val _externalAccums = new ArrayBuffer[AccumulatorV2[_, _]] + + /** + * Perform an `op` conversion on the `_externalAccums` within the read lock. + * + * Note `op` is expected to not modify the `_externalAccums` and not being + * lazy evaluation for safe concern since `ArrayBuffer` is lazily evaluated. + * And we intentionally keeps `_externalAccums` as mutable instead of converting + * it to immutable for the performance concern. 
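The TaskMetrics changes around this hunk swap the CopyOnWriteArrayList for a plain ArrayBuffer guarded by a ReentrantReadWriteLock and expose it only through a withExternalAccums accessor (defined just below). The same guarded-accessor pattern in a generic, self-contained sketch (illustrative only, not Spark's class):

  import java.util.concurrent.locks.ReentrantReadWriteLock
  import scala.collection.mutable.ArrayBuffer

  class GuardedRegistry[T] {
    private val lock = new ReentrantReadWriteLock()
    private val items = new ArrayBuffer[T]

    // Readers never receive the buffer directly; `op` runs under the read lock.
    def withItems[R](op: ArrayBuffer[T] => R): R = {
      lock.readLock().lock()
      try op(items) finally lock.readLock().unlock()
    }

    // Registration is the only mutation and takes the write lock.
    def register(item: T): Unit = {
      lock.writeLock().lock()
      try items += item finally lock.writeLock().unlock()
    }
  }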
+ */ + private[spark] def withExternalAccums[T](op: ArrayBuffer[AccumulatorV2[_, _]] => T) + : T = withReadLock { + op(_externalAccums) + } - private[spark] def externalAccums = _externalAccums.asScala + private def withReadLock[B](fn: => B): B = { + readLock.lock() + try { + fn + } finally { + readLock.unlock() + } + } - private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = { - _externalAccums.add(a) + private def withWriteLock[B](fn: => B): B = { + writeLock.lock() + try { + fn + } finally { + writeLock.unlock() + } } - private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums + private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = withWriteLock { + _externalAccums += a + } + + private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = withReadLock { + internalAccums ++ _externalAccums + } private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its @@ -335,7 +370,7 @@ private[spark] object TaskMetrics extends Logging { tmAcc.metadata = acc.metadata tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) } else { - tm._externalAccums.add(acc) + tm._externalAccums += acc } } tm diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 600cbf151e17b..938d6ec2e01b0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -924,7 +924,7 @@ package object config { private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.executor.maxNumFailures") - .doc("Spark exits if the number of failed executors exceeds this threshold. " + + .doc("The maximum number of executor failures before failing the application. " + "This configuration only takes effect on YARN, or Kubernetes when " + "`spark.kubernetes.allocation.pods.allocator` is set to 'direct'.") .version("3.5.0") @@ -933,7 +933,7 @@ package object config { private[spark] val EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS = ConfigBuilder("spark.executor.failuresValidityInterval") - .doc("Interval after which Executor failures will be considered independent and not " + + .doc("Interval after which executor failures will be considered independent and not " + "accumulate towards the attempt count. This configuration only takes effect on YARN, " + "or Kubernetes when `spark.kubernetes.allocation.pods.allocator` is set to 'direct'.") .version("3.5.0") @@ -1155,7 +1155,7 @@ package object config { "like YARN and event logs.") .version("2.1.2") .regexConf - .createWithDefault("(?i)secret|password|token|access[.]key".r) + .createWithDefault("(?i)secret|password|token|access[.]?key".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a21d2ae773966..f695b10202758 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -223,14 +223,17 @@ abstract class RDD[T: ClassTag]( * not use `this` because RDDs are user-visible, so users might have added their own locking on * RDDs; sharing that could lead to a deadlock. 
* - * One thread might hold the lock on many of these, for a chain of RDD dependencies; but - * because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no - * chance of deadlock. + * One thread might hold the lock on many of these, for a chain of RDD dependencies. Deadlocks + * are possible if we try to lock another resource while holding the stateLock, + * and the lock acquisition sequence of these locks is not guaranteed to be the same. + * This can lead lead to a deadlock as one thread might first acquire the stateLock, + * and then the resource, + * while another thread might first acquire the resource, and then the stateLock. * * Executors may reference the shared fields (though they should never mutate them, * that only happens on the driver). */ - private val stateLock = new Serializable {} + private[spark] val stateLock = new Serializable {} // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 0a93023443704..3c1451a01850d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val parRDDs = if (isPartitionListingParallel) { + // scalastyle:off parvector val parArray = new ParVector(rdds.toVector) parArray.tasksupport = UnionRDD.partitionEvalTaskSupport + // scalastyle:on parvector parArray } else { rdds diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala index 9f98d4d9c9c79..afbacb8013645 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala @@ -67,9 +67,10 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf, */ private[spark] def isSupported(rp: ResourceProfile): Boolean = { if (rp.isInstanceOf[TaskResourceProfile] && !dynamicEnabled) { - if ((notRunningUnitTests || testExceptionThrown) && !isStandaloneOrLocalCluster) { - throw new SparkException("TaskResourceProfiles are only supported for Standalone " + - "cluster for now when dynamic allocation is disabled.") + if ((notRunningUnitTests || testExceptionThrown) && + !(isStandaloneOrLocalCluster || isYarn || isK8s)) { + throw new SparkException("TaskResourceProfiles are only supported for Standalone, " + + "Yarn and Kubernetes cluster for now when dynamic allocation is disabled.") } } else { val isNotDefaultProfile = rp.id != ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index fc83439454dcf..89d16e5793482 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -174,6 +174,9 @@ private[spark] class DAGScheduler( * locations where that RDD partition is cached. * * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + * If you need to access any RDD while synchronizing on the cache locations, + * first synchronize on the RDD, and then synchronize on this map to avoid deadlocks. 
The RDD + * could try to access the cache locations after synchronizing on the RDD. */ private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] @@ -420,22 +423,24 @@ private[spark] class DAGScheduler( } private[scheduler] - def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { - // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times - if (!cacheLocs.contains(rdd.id)) { - // Note: if the storage level is NONE, we don't need to get locations from block manager. - val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { - IndexedSeq.fill(rdd.partitions.length)(Nil) - } else { - val blockIds = - rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = rdd.stateLock.synchronized { + cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times + if (!cacheLocs.contains(rdd.id)) { + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } + cacheLocs(rdd.id) = locs } - cacheLocs(rdd.id) = locs + cacheLocs(rdd.id) } - cacheLocs(rdd.id) } private def clearCacheLocs(): Unit = cacheLocs.synchronized { @@ -1847,9 +1852,9 @@ private[spark] class DAGScheduler( case Success => // An earlier attempt of a stage (which is zombie) may still have running tasks. If these // tasks complete, they still count and we can mark the corresponding partitions as - // finished. Here we notify the task scheduler to skip running tasks for the same partition, - // to save resource. - if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { + // finished if the stage is determinate. Here we notify the task scheduler to skip running + // tasks for the same partition to save resource. 
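getCacheLocs above now synchronizes on rdd.stateLock before taking the cacheLocs lock, which enforces the ordering described in the comment: if every path acquires the per-RDD lock first and the shared map lock second, two threads can never hold the locks in opposite orders. A tiny generic sketch of that discipline (illustrative, not the DAGScheduler code):

  // lockA plays the role of the RDD's stateLock, lockB the shared cacheLocs map.
  def withOrderedLocks[T](lockA: AnyRef, lockB: AnyRef)(body: => T): T =
    lockA.synchronized {
      lockB.synchronized {
        body
      }
    }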
+ if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) { taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) } @@ -1903,19 +1908,26 @@ private[spark] class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] - shuffleStage.pendingPartitions -= task.partitionId - val status = event.result.asInstanceOf[MapStatus] - val execId = status.location.executorId - logDebug("ShuffleMapTask finished on " + execId) - if (executorFailureEpoch.contains(execId) && + // Ignore task completion for old attempt of indeterminate stage + val ignoreIndeterminate = stage.isIndeterminate && + task.stageAttemptId < stage.latestInfo.attemptNumber() + if (!ignoreIndeterminate) { + shuffleStage.pendingPartitions -= task.partitionId + val status = event.result.asInstanceOf[MapStatus] + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (executorFailureEpoch.contains(execId) && smt.epoch <= executorFailureEpoch(execId)) { - logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") + } else { + // The epoch of the task is acceptable (i.e., the task was launched after the most + // recent failure we're aware of for the executor), so mark the task's output as + // available. + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + } } else { - // The epoch of the task is acceptable (i.e., the task was launched after the most - // recent failure we're aware of for the executor), so mark the task's output as - // available. - mapOutputTracker.registerMapOutput( - shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + logInfo(s"Ignoring $smt completion from an older attempt of indeterminate stage") } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index cd5d6b8f9c90d..a5858ebf9cdcc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -44,10 +44,7 @@ private case class AskPermissionToCommitOutput( * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. */ -private[spark] class OutputCommitCoordinator( - conf: SparkConf, - isDriver: Boolean, - sc: Option[SparkContext] = None) extends Logging { +private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging { // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None @@ -158,10 +155,9 @@ private[spark] class OutputCommitCoordinator( val taskId = TaskIdentifier(stageAttempt, attemptNumber) stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId if (stageState.authorizedCommitters(partition) == taskId) { - sc.foreach(_.dagScheduler.stageFailed(stage, s"Authorized committer " + - s"(attemptNumber=$attemptNumber, stage=$stage, partition=$partition) failed; " + - s"but task commit success, data duplication may happen. 
" + - s"reason=$reason")) + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") + stageState.authorizedCommitters(partition) = null } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 39667ea2364db..69ef094417b68 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -210,7 +210,7 @@ private[spark] abstract class Task[T]( context.taskMetrics.nonZeroInternalAccums() ++ // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not // filter them out. - context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + context.taskMetrics.withExternalAccums(_.filter(a => !taskFailed || a.countFailedValues)) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0cb970fd27880..5e9716dfcfe90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -901,7 +901,7 @@ private[spark] class TaskSchedulerImpl( executorRunTime = acc.value.asInstanceOf[Long] } } - acc.toInfo(Some(acc.value), None) + acc.toInfoUpdate } val taskProcessRate = if (efficientTaskCalcualtionEnabled) { getTaskProcessRate(recordsRead, executorRunTime) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 919b0f5f7c135..34eea575bbfd2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,6 +21,7 @@ import java.io._ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.file.Files +import java.util.{Collections, Map => JMap} import scala.collection.mutable.ArrayBuffer @@ -37,6 +38,7 @@ import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils +import org.apache.spark.util.collection.OpenHashSet /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. 
@@ -52,7 +54,8 @@ import org.apache.spark.util.Utils private[spark] class IndexShuffleBlockResolver( conf: SparkConf, // var for testing - var _blockManager: BlockManager = null) + var _blockManager: BlockManager = null, + val taskIdMapsForShuffle: JMap[Int, OpenHashSet[Long]] = Collections.emptyMap()) extends ShuffleBlockResolver with Logging with MigratableResolver { @@ -270,6 +273,21 @@ private[spark] class IndexShuffleBlockResolver( throw SparkCoreErrors.failedRenameTempFileError(fileTmp, file) } } + blockId match { + case ShuffleIndexBlockId(shuffleId, mapId, _) => + val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent( + shuffleId, _ => new OpenHashSet[Long](8) + ) + mapTaskIds.add(mapId) + + case ShuffleDataBlockId(shuffleId, mapId, _) => + val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent( + shuffleId, _ => new OpenHashSet[Long](8) + ) + mapTaskIds.add(mapId) + + case _ => // Unreachable + } blockManager.reportBlockStatus(blockId, BlockStatus(StorageLevel.DISK_ONLY, 0, diskSize)) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 46aca07ce43f6..4234d0ec5fd04 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -87,7 +87,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + override val shuffleBlockResolver = + new IndexShuffleBlockResolver(conf, taskIdMapsForShuffle = taskIdMapsForShuffle) /** * Obtains a [[ShuffleHandle]] to pass to tasks. 
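The resolver change above records, for each shuffle id, the map task ids whose index or data files it writes, using a map shared with the shuffle manager (SortShuffleManager passes its own taskIdMapsForShuffle into the constructor). Below is a rough standalone illustration of that per-shuffle registration pattern, using plain JDK collections instead of Spark's OpenHashSet; all names here are illustrative.

import java.util.concurrent.ConcurrentHashMap
import java.util.{Set => JSet}

object TaskIdRegistry {
  // shuffleId -> ids of map tasks whose output files exist locally.
  private val taskIdsForShuffle = new ConcurrentHashMap[Int, JSet[Long]]()

  def recordMapOutput(shuffleId: Int, mapId: Long): Unit = {
    // Create the per-shuffle set on first use, then add the map task id to it.
    val ids = taskIdsForShuffle.computeIfAbsent(
      shuffleId, _ => ConcurrentHashMap.newKeySet[Long]())
    ids.add(mapId)
  }

  // Everything recorded for a shuffle, e.g. when its files need to be cleaned up.
  def mapIdsFor(shuffleId: Int): JSet[Long] =
    taskIdsForShuffle.getOrDefault(shuffleId, ConcurrentHashMap.newKeySet[Long]())
}

Because both components hold the same map instance, ids registered while writing through the resolver are visible wherever the shuffle manager later consults the map.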
@@ -176,7 +177,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) + new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 8613fe11a4c2f..3be7d24f7e4ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -21,6 +21,7 @@ import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.util.collection.ExternalSorter @@ -28,6 +29,7 @@ private[spark] class SortShuffleWriter[K, V, C]( handle: BaseShuffleHandle[K, V, C], mapId: Long, context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { @@ -46,8 +48,6 @@ private[spark] class SortShuffleWriter[K, V, C]( private var partitionLengths: Array[Long] = _ - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 45ebb6eafa69f..ab4073fe8c05c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -415,13 +415,14 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false * then just go ahead and acquire the write lock. Otherwise, if another thread is already * writing the block, then we wait for the write to finish before acquiring the read lock. * - * @return true if the block did not already exist, false otherwise. If this returns false, then - * a read lock on the existing block will be held. If this returns true, a write lock on - * the new block will be held. + * @return true if the block did not already exist, false otherwise. + * If this returns true, a write lock on the new block will be held. + * If this returns false then a read lock will be held iff keepReadLock == true. 
*/ def lockNewBlockForWriting( blockId: BlockId, - newBlockInfo: BlockInfo): Boolean = { + newBlockInfo: BlockInfo, + keepReadLock: Boolean = true): Boolean = { logTrace(s"Task $currentTaskAttemptId trying to put $blockId") // Get the lock that will be associated with the to-be written block and lock it for the entire // duration of this operation. This way we prevent race conditions when two threads try to write @@ -449,6 +450,8 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false val result = lockForWriting(blockId, blocking = false) assert(result.isDefined) return true + } else if (!keepReadLock) { + return false } else { // Block already exists. This could happen if another thread races with us to compute // the same block. In this case we try to acquire a read lock, if the locking succeeds diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 05d57c67576a5..1b56aa7ade125 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1510,14 +1510,10 @@ private[spark] class BlockManager( val putBlockInfo = { val newInfo = new BlockInfo(level, classTag, tellMaster) - if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) { + if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo, keepReadLock)) { newInfo } else { logWarning(s"Block $blockId already exists on this machine; not re-adding it") - if (!keepReadLock) { - // lockNewBlockForWriting returned a read lock on the existing block, so we must free it: - releaseLock(blockId) - } return None } } @@ -2086,8 +2082,10 @@ private[spark] class BlockManager( hasRemoveBlock = true if (tellMaster) { // Only update storage level from the captured block status before deleting, so that - // memory size and disk size are being kept for calculating delta. - reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = StorageLevel.NONE)) + // memory size and disk size are being kept for calculating delta. Reset the replica + // count 0 in storage level to notify that it is a remove operation. + val storageLevel = StorageLevel(blockStatus.get.storageLevel.toInt, 0) + reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = storageLevel)) } } finally { if (!hasRemoveBlock) { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f8bd73e65617f..ebc10300ef471 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -125,6 +125,12 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsCommitted = 0L + // For testing only. + private[storage] def getSerializerWrappedStream: OutputStream = bs + + // For testing only. + private[storage] def getSerializationStream: SerializationStream = objOut + /** * Set the checksum that the checksumOutputStream should use */ @@ -173,19 +179,35 @@ private[spark] class DiskBlockObjectWriter( * Should call after committing or reverting partial writes. 
*/ private def closeResources(): Unit = { - if (initialized) { - Utils.tryWithSafeFinally { - mcs.manualClose() - } { - channel = null - mcs = null - bs = null - fos = null - ts = null - objOut = null - initialized = false - streamOpen = false - hasBeenClosed = true + try { + if (streamOpen) { + Utils.tryWithSafeFinally { + if (null != objOut) objOut.close() + bs = null + } { + objOut = null + if (null != bs) bs.close() + bs = null + } + } + } catch { + case e: IOException => + logInfo("Exception occurred while closing the output stream: " + e.getMessage) + } finally { + if (initialized) { + Utils.tryWithSafeFinally { + mcs.manualClose() + } { + channel = null + mcs = null + bs = null + fos = null + ts = null + objOut = null + initialized = false + streamOpen = false + hasBeenClosed = true + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1cb5adef5f460..304aa01c7ee42 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -34,6 +34,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer @@ -309,7 +310,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: private var _transferred = 0L - private val buffer = ByteBuffer.allocateDirect(64 * 1024) + private val buffer = Platform.allocateDirectBuffer(64 * 1024) buffer.flip() override def count(): Long = blockSize diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala index eb23fb4b1c84d..161120393490f 100644 --- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala +++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala @@ -188,15 +188,15 @@ private[spark] object FallbackStorage extends Logging { val name = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID).name val hash = JavaUtils.nonNegativeHash(name) val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") - val f = fallbackFileSystem.open(dataFile) val size = nextOffset - offset logDebug(s"To byte array $size") val array = new Array[Byte](size.toInt) val startTimeNs = System.nanoTime() - f.seek(offset) - f.readFully(array) - logDebug(s"Took ${(System.nanoTime() - startTimeNs) / (1000 * 1000)}ms") - f.close() + Utils.tryWithResource(fallbackFileSystem.open(dataFile)) { f => + f.seek(offset) + f.readFully(array) + logDebug(s"Took ${(System.nanoTime() - startTimeNs) / (1000 * 1000)}ms") + } new NioManagedBuffer(ByteBuffer.wrap(array)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b21a2aa1c1791..17407f4ee21f5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1142,6 +1142,12 @@ final class ShuffleBlockFetcherIterator( s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." 
logWarning(diagnosisResponse) diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse case unexpected: BlockId => throw SparkException.internalError( s"Unexpected type of BlockId, $unexpected", category = "STORAGE") @@ -1354,7 +1360,8 @@ private class BufferReleasingInputStream( } } - override def available(): Int = delegate.available() + override def available(): Int = + tryOrFetchFailedException(delegate.available()) override def mark(readlimit: Int): Unit = delegate.mark(readlimit) @@ -1369,12 +1376,13 @@ private class BufferReleasingInputStream( override def read(b: Array[Byte], off: Int, len: Int): Int = tryOrFetchFailedException(delegate.read(b, off, len)) - override def reset(): Unit = delegate.reset() + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) /** * Execute a block of code that returns a value, close this stream quietly and re-throwing * IOException as FetchFailedException when detectCorruption is true. This method is only - * used by the `read` and `skip` methods inside `BufferReleasingInputStream` currently. + * used by the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` + * currently. */ private def tryOrFetchFailedException[T](block: => T): T = { try { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 9582bdbf52641..21753361e627a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -312,6 +312,12 @@ private[spark] object JettyUtils extends Logging { logDebug(s"Using requestHeaderSize: $requestHeaderSize") httpConfig.setRequestHeaderSize(requestHeaderSize) + // Hide information. + logDebug("Using setSendServerVersion: false") + httpConfig.setSendServerVersion(false) + logDebug("Using setSendXPoweredBy: false") + httpConfig.setSendXPoweredBy(false) + // If SSL is configured, create the secure connector first. val securePort = sslOptions.createJettySslContextFactory().map { factory => val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 286c0a1625150..695f6d54e8f53 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -551,8 +551,8 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like <script> will cause in - * the whole description to be treated as plain text. + * attempts to embed links outside Spark UI, other tags like <script>, or inline scripts + * like `onclick` will cause in the whole description to be treated as plain text. * * @param desc the original job or stage description string, which may contain html tags. 
* @param basePathUri with which to prepend the relative links; this is used when plainText is @@ -572,7 +572,13 @@ private[spark] object UIUtils extends Logging { // Verify that this has only anchors and span (we are wrapping in span) val allowedNodeLabels = Set("a", "span", "br") - val illegalNodes = (xml \\ "_").filterNot(node => allowedNodeLabels.contains(node.label)) + val allowedAttributes = Set("class", "href") + val illegalNodes = + (xml \\ "_").filterNot { node => + allowedNodeLabels.contains(node.label) && + // Verify we only have href attributes + node.attributes.map(_.key).forall(allowedAttributes.contains) + } if (illegalNodes.nonEmpty) { throw new IllegalArgumentException( "Only HTML anchors allowed in job descriptions\n" + diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 181033c9d20c8..aadde1e20226a 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -102,16 +102,24 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { metadata.countFailedValues } + private def isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) + /** * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] with the provided * values. */ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { - val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) AccumulableInfo(id, name, internOption(update), internOption(value), isInternal, countFailedValues) } + /** + * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] as an update. + */ + private[spark] def toInfoUpdate: AccumulableInfo = { + AccumulableInfo(id, name, internOption(Some(value)), None, isInternal, countFailedValues) + } + final private[spark] def isAtDriverSide: Boolean = atDriverSide /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 025e5d5bac94b..377caf776deb0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -1350,7 +1350,7 @@ private[spark] object JsonProtocol extends JsonUtils { val accumUpdates = jsonOption(json.get("Accumulator Updates")) .map(_.extractElements.map(accumulableInfoFromJson).toArray.toSeq) .getOrElse(taskMetricsFromJson(json.get("Metrics")).accumulators().map(acc => { - acc.toInfo(Some(acc.value), None) + acc.toInfoUpdate }).toArray.toSeq) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 16d7de56c39eb..2d3d6ec89ffbd 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -363,6 +363,10 @@ private[spark] object ThreadUtils { * Comparing to the map() method of Scala parallel collections, this method can be interrupted * at any time. This is useful on canceling of task execution, for example. * + * Functions are guaranteed to be executed in freshly-created threads that inherit the calling + * thread's Spark thread-local variables. These threads also inherit the calling thread's active + * SparkSession. 
+ * * @param in - the input collection which should be transformed in parallel. * @param prefix - the prefix assigned to the underlying thread pool. * @param maxThreads - maximum number of thread can be created during execution. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7153bb72476a7..2f2734a389ff0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter} import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} @@ -696,7 +696,8 @@ private[spark] class ExternalSorter[K, V, C]( def writePartitionedMapOutput( shuffleId: Int, mapId: Long, - mapOutputWriter: ShuffleMapOutputWriter): Unit = { + mapOutputWriter: ShuffleMapOutputWriter, + writeMetrics: ShuffleWriteMetricsReporter): Unit = { if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer @@ -713,7 +714,7 @@ private[spark] class ExternalSorter[K, V, C]( serializerManager, serInstance, blockId, - context.taskMetrics().shuffleWriteMetrics, + writeMetrics, if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null) while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(partitionPairsWriter) @@ -737,7 +738,7 @@ private[spark] class ExternalSorter[K, V, C]( serializerManager, serInstance, blockId, - context.taskMetrics().shuffleWriteMetrics, + writeMetrics, if (partitionChecksums.nonEmpty) partitionChecksums(id) else null) if (elements.hasNext) { for (elem <- elements) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 6815e47a198d9..4e307e35da8cd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -126,6 +126,17 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( this } + /** + * Check if a key exists at the provided position using object equality rather than + * cooperative equality. Otherwise, hash sets will mishandle values for which `==` + * and `equals` return different results, like 0.0/-0.0 and NaN/NaN. + * + * See: https://issues.apache.org/jira/browse/SPARK-45599 + */ + @annotation.nowarn + private def keyExistsAtPos(k: T, pos: Int) = + _data(pos) equals k + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. @@ -146,8 +157,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( _bitset.set(pos) _size += 1 return pos | NONEXISTENCE_MASK - } else if (_data(pos) == k) { - // Found an existing key. + } else if (keyExistsAtPos(k, pos)) { return pos } else { // quadratic probing with values increase by 1, 2, 3, ... 
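For context on the keyExistsAtPos helper introduced above: Scala's cooperative equality (==) and equals disagree on exactly the corner cases named in SPARK-45599, so a hash set that probes with == can mishandle 0.0/-0.0 and NaN keys. A small self-contained check of those semantics (plain Scala, independent of Spark's OpenHashSet):

object EqualitySemantics {
  def main(args: Array[String]): Unit = {
    // Cooperative equality on primitive doubles follows IEEE-754 rules.
    println(0.0 == -0.0)                    // true
    println(Double.NaN == Double.NaN)       // false

    // equals() boxes to java.lang.Double and compares bit patterns, which is
    // the notion of equality that hashCode() is consistent with.
    println((0.0).equals(-0.0))             // false
    println(Double.NaN.equals(Double.NaN))  // true
  }
}

Since lookups must agree with how keys were stored, probing with equals keeps -0.0 distinct from 0.0 and makes a stored NaN findable again, which is what the new helper provides for the specialized set.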
@@ -181,7 +191,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( while (true) { if (!_bitset.get(pos)) { return INVALID_POS - } else if (k == _data(pos)) { + } else if (keyExistsAtPos(k, pos)) { return pos } else { // quadratic probing with values increase by 1, 2, 3, ... diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index d3aa93549a83a..472d03baeae05 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -69,6 +69,7 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper { File tempDir; long[] partitionSizesInMergedFile; final LinkedList spillFilesCreated = new LinkedList<>(); + long totalSpilledDiskBytes = 0; SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf().set("spark.kryo.unsafe", "false")); @@ -96,6 +97,7 @@ public void setUp() throws Exception { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); + totalSpilledDiskBytes = 0; conf = new SparkConf() .set(package$.MODULE$.BUFFER_PAGESIZE().key(), "1m") .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false) @@ -160,7 +162,11 @@ public void setUp() throws Exception { when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); + File file = spy(File.createTempFile("spillFile", ".spill", tempDir)); + when(file.delete()).thenAnswer(inv -> { + totalSpilledDiskBytes += file.length(); + return inv.callRealMethod(); + }); spillFilesCreated.add(file); return Tuple2$.MODULE$.apply(blockId, file); }); @@ -284,6 +290,9 @@ public void writeWithoutSpilling() throws Exception { final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); + // Even if there is no spill, the sorter still writes its data to a spill file at the end, + // which will become the final shuffle file. 
+ assertEquals(1, spillFilesCreated.size()); long sumOfPartitionSizes = 0; for (long size: partitionSizesInMergedFile) { @@ -305,7 +314,8 @@ public void writeWithoutSpilling() throws Exception { @Test public void writeChecksumFileWithoutSpill() throws Exception { - IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); + IndexShuffleBlockResolver blockResolver = + new IndexShuffleBlockResolver(conf, blockManager, Collections.emptyMap()); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); @@ -335,7 +345,8 @@ public void writeChecksumFileWithoutSpill() throws Exception { @Test public void writeChecksumFileWithSpill() throws Exception { - IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); + IndexShuffleBlockResolver blockResolver = + new IndexShuffleBlockResolver(conf, blockManager, Collections.emptyMap()); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); @@ -425,9 +436,8 @@ private void testMergingSpills( assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertTrue(taskMetrics.diskBytesSpilled() > 0L); - assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); assertTrue(taskMetrics.memoryBytesSpilled() > 0L); + assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled()); assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @@ -517,9 +527,8 @@ public void writeEnoughDataToTriggerSpill() throws Exception { assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertTrue(taskMetrics.diskBytesSpilled() > 0L); - assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); assertTrue(taskMetrics.memoryBytesSpilled()> 0L); + assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled()); assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @@ -550,9 +559,8 @@ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exc assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); - assertTrue(taskMetrics.diskBytesSpilled() > 0L); - assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length()); assertTrue(taskMetrics.memoryBytesSpilled()> 0L); + assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled()); assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 9b70ccdf07e1b..a9d7e8a0f2eda 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -147,7 +147,7 @@ private[spark] object AccumulatorSuite { * Make an `AccumulableInfo` out of an `AccumulatorV2` with the intent to use the * info as an accumulator update. 
*/ - def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) + def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfoUpdate /** * Run one or more Spark jobs and verify that in at least one job the peak execution memory diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 450ff01921a83..d6f925ddced92 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -1109,4 +1109,59 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { rpcEnv.shutdown() } } + + test( + "SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after removeOutputsByFilter" + ) { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.removeOutputsOnHost("hostA") + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } + + test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after unregisterMapOutput") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + tracker.unregisterMapOutput(0, 0, BlockManagerId("exec-1", "hostA", 1000)) + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } + + test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after registerMapOutput") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + try { + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(0, 1, 1) + tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000), + Array(2L), 0)) + // Another task also finished working on partition 0. 
+ tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-2", "hostB", 1000), + Array(2L), 1)) + assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 1) + } finally { + tracker.stop() + rpcEnv.shutdown() + } + } } diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 797b650799eaf..795da65079d6e 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -123,7 +123,6 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { } override def afterAll(): Unit = { - tracker.stop() if (sc != null) { sc.stop() } diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index f5819b9508777..1163088c82aa8 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -302,6 +302,30 @@ abstract class SparkFunSuite } } + /** + * Sets all configurations specified in `pairs` in SparkEnv SparkConf, calls `f`, and then + * restores all configurations. + */ + protected def withSparkEnvConfs(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.getOption(key).isDefined) { + Some(conf.get(key)) + } else { + None + } + } + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.set(key, value) + case (key, None) => conf.remove(key) + } + } + } + /** * Checks an exception with an error class against expected results. * @param exception The exception to check diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 0249cde54884b..a5f5eb21c68b3 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -253,8 +253,7 @@ class SparkThrowableSuite extends SparkFunSuite { | |Also see [SQLSTATE Codes](sql-error-conditions-sqlstates.html). 
| - |$sqlErrorParentDocContent - |""".stripMargin + |$sqlErrorParentDocContent""".stripMargin errors.filter(_._2.subClass.isDefined).foreach(error => { val name = error._1 @@ -316,7 +315,7 @@ class SparkThrowableSuite extends SparkFunSuite { } FileUtils.writeStringToFile( parentDocPath.toFile, - sqlErrorParentDoc + lineSeparator, + sqlErrorParentDoc, StandardCharsets.UTF_8) } } else { @@ -417,6 +416,7 @@ class SparkThrowableSuite extends SparkFunSuite { } catch { case e: SparkThrowable => assert(e.getErrorClass == null) + assert(!e.isInternalError) assert(e.getSqlState == null) case _: Throwable => // Should not end up here @@ -433,6 +433,7 @@ class SparkThrowableSuite extends SparkFunSuite { } catch { case e: SparkThrowable => assert(e.getErrorClass == "CANNOT_PARSE_DECIMAL") + assert(!e.isInternalError) assert(e.getSqlState == "22018") case _: Throwable => // Should not end up here diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 0817abbc6a328..9019ea484b3f3 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -140,16 +140,19 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont } sc.removeJobTag("tag1") + // takeAsync() across multiple partitions val thirdJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) - val thirdJobId = eventually(timeout(10.seconds)) { - thirdJobFuture.jobIds.head + val thirdJobIds = eventually(timeout(10.seconds)) { + // Wait for the two jobs triggered by takeAsync + thirdJobFuture.jobIds.size should be(2) + thirdJobFuture.jobIds } eventually(timeout(10.seconds)) { sc.statusTracker.getJobIdsForTag("tag1").toSet should be ( Set(firstJobId, secondJobId)) sc.statusTracker.getJobIdsForTag("tag2").toSet should be ( - Set(secondJobId, thirdJobId)) + Set(secondJobId) ++ thirdJobIds) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitTestUtils.scala index 2ab2e17df03a8..932e972374cae 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitTestUtils.scala @@ -18,8 +18,6 @@ package org.apache.spark.deploy import java.io.File -import java.sql.Timestamp -import java.util.Date import scala.collection.mutable.ArrayBuffer @@ -69,17 +67,8 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { env.put("SPARK_HOME", sparkHome) def captureOutput(source: String)(line: String): Unit = { - // This test suite has some weird behaviors when executed on Jenkins: - // - // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a - // timestamp to provide more diagnosis information. - // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print - // them out for debugging purposes. 
- val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" - // scalastyle:off println - println(logLine) - // scalastyle:on println - history += logLine + logInfo(s"$source> $line") + history += line } val process = builder.start() diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala index efb8393403043..f34f792881f90 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala @@ -229,6 +229,37 @@ class SingleFileEventLogFileReaderSuite extends EventLogFileReadersSuite { } class RollingEventLogFilesReaderSuite extends EventLogFileReadersSuite { + test("SPARK-46012: appStatus file should exist") { + withTempDir { dir => + val appId = getUniqueApplicationId + val attemptId = None + + val conf = getLoggingConf(testDirPath) + conf.set(EVENT_LOG_ENABLE_ROLLING, true) + conf.set(EVENT_LOG_ROLLING_MAX_FILE_SIZE.key, "10m") + + val writer = createWriter(appId, attemptId, testDirPath.toUri, conf, + SparkHadoopUtil.get.newConfiguration(conf)) + + writer.start() + val dummyStr = "dummy" * 1024 + writeTestEvents(writer, dummyStr, 1024 * 1024 * 20) + writer.stop() + + // Verify a healthy rolling event log directory + val logPathCompleted = getCurrentLogPath(writer.logPath, isCompleted = true) + val readerOpt = EventLogFileReader(fileSystem, new Path(logPathCompleted)) + assert(readerOpt.get.isInstanceOf[RollingEventLogFilesFileReader]) + assert(readerOpt.get.listEventLogFiles.length === 3) + + // Make unhealthy rolling event directory by removing appStatus file. + val appStatusFile = fileSystem.listStatus(new Path(logPathCompleted)) + .find(RollingEventLogFilesWriter.isAppStatusFile).get.getPath + fileSystem.delete(appStatusFile, false) + assert(EventLogFileReader(fileSystem, new Path(logPathCompleted)).isEmpty) + } + } + allCodecs.foreach { codecShortName => test(s"rolling event log files - codec $codecShortName") { val appId = getUniqueApplicationId diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 1cec863b1e7f9..37874de987662 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.master +import java.net.{HttpURLConnection, URL} import java.util.Date import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -325,6 +326,26 @@ class MasterSuite extends SparkFunSuite } } + test("SPARK-46888: master should reject worker kill request if decommision is disabled") { + implicit val formats = org.json4s.DefaultFormats + val conf = new SparkConf() + .set(DECOMMISSION_ENABLED, false) + .set(MASTER_UI_DECOMMISSION_ALLOW_MODE, "ALLOW") + val localCluster = LocalSparkCluster(1, 1, 512, conf) + localCluster.start() + val masterUrl = s"http://${Utils.localHostNameForURI()}:${localCluster.masterWebUIPort}" + try { + eventually(timeout(30.seconds), interval(100.milliseconds)) { + val url = new URL(s"$masterUrl/workers/kill/?host=${Utils.localHostNameForURI()}") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + assert(conn.getResponseCode === 405) + } + } finally { + 
localCluster.stop() + } + } + test("master/worker web ui available") { implicit val formats = org.json4s.DefaultFormats val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala new file mode 100644 index 0000000000000..e1edef8f4155c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date +import javax.servlet.http.HttpServletRequest + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, ApplicationState, Master} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEndpointRef + +class ApplicationPageSuite extends SparkFunSuite { + + private val master = mock(classOf[Master]) + + private val rp = new ResourceProfile(Map.empty, Map.empty) + private val desc = ApplicationDescription("name", Some(4), null, "appUiUrl", rp) + private val descWithoutUI = ApplicationDescription("name", Some(4), null, "", rp) + private val appFinished = new ApplicationInfo(0, "app-finished", desc, new Date, null, 1) + appFinished.markFinished(ApplicationState.FINISHED) + private val appLive = new ApplicationInfo(0, "app-live", desc, new Date, null, 1) + private val appLiveWithoutUI = + new ApplicationInfo(0, "app-live-without-ui", descWithoutUI, new Date, null, 1) + + private val state = mock(classOf[MasterStateResponse]) + when(state.completedApps).thenReturn(Array(appFinished)) + when(state.activeApps).thenReturn(Array(appLive, appLiveWithoutUI)) + + private val rpc = mock(classOf[RpcEndpointRef]) + when(rpc.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) + + private val masterWebUI = mock(classOf[MasterWebUI]) + when(masterWebUI.master).thenReturn(master) + when(masterWebUI.masterEndpointRef).thenReturn(rpc) + + test("SPARK-45774: Application Detail UI") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-live") + + val result = new ApplicationPage(masterWebUI).render(request).toString() + assert(result.contains("Application Detail UI")) + assert(!result.contains("Application History UI")) + } + + test("SPARK-50021: Application Detail UI is empty when spark.ui.enabled=false") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-live-without-ui") + + val result = 
new ApplicationPage(masterWebUI).render(request).toString() + assert(result.contains("Application UI: Disabled")) + assert(!result.contains("Application History UI")) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index 024511189accc..bda3309ad8208 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -30,12 +30,14 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, KillDriverResponse, RequestKillDriver} import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ +import org.apache.spark.internal.config.DECOMMISSION_ENABLED import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.util.Utils class MasterWebUISuite extends SparkFunSuite { + import MasterWebUISuite._ - val conf = new SparkConf() + val conf = new SparkConf().set(DECOMMISSION_ENABLED, true) val securityMgr = new SecurityManager(conf) val rpcEnv = mock(classOf[RpcEnv]) val master = mock(classOf[Master]) @@ -112,12 +114,14 @@ class MasterWebUISuite extends SparkFunSuite { private def convPostDataToString(data: Map[String, String]): String = { convPostDataToString(data.toSeq) } +} +object MasterWebUISuite { /** * Send an HTTP request to the given URL using the method and the body specified. * Return the connection object. */ - private def sendHttpRequest( + private[ui] def sendHttpRequest( url: String, method: String, body: String = ""): HttpURLConnection = { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala new file mode 100644 index 0000000000000..9fd5431418aa2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ReadOnlyMasterWebUISuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date +import javax.servlet.http.HttpServletResponse.SC_OK + +import scala.io.Source + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master._ +import org.apache.spark.deploy.master.ui.MasterWebUISuite._ +import org.apache.spark.internal.config.DECOMMISSION_ENABLED +import org.apache.spark.internal.config.UI.UI_KILL_ENABLED +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils + +class ReadOnlyMasterWebUISuite extends SparkFunSuite { + + import org.apache.spark.deploy.DeployTestUtils._ + + val conf = new SparkConf() + .set(UI_KILL_ENABLED, false) + .set(DECOMMISSION_ENABLED, false) + val securityMgr = new SecurityManager(conf) + val rpcEnv = mock(classOf[RpcEnv]) + val master = mock(classOf[Master]) + val masterEndpointRef = mock(classOf[RpcEndpointRef]) + when(master.securityMgr).thenReturn(securityMgr) + when(master.conf).thenReturn(conf) + when(master.rpcEnv).thenReturn(rpcEnv) + when(master.self).thenReturn(masterEndpointRef) + val desc1 = createAppDesc().copy(name = "WithUI") + val desc2 = desc1.copy(name = "WithoutUI", appUiUrl = "") + val app1 = new ApplicationInfo(new Date().getTime, "app1", desc1, new Date(), null, Int.MaxValue) + val app2 = new ApplicationInfo(new Date().getTime, "app2", desc2, new Date(), null, Int.MaxValue) + val state = new MasterStateResponse( + "host", 8080, None, Array.empty, Array(app1, app2), Array.empty, + Array.empty, Array.empty, RecoveryState.ALIVE) + when(masterEndpointRef.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) + val masterWebUI = new MasterWebUI(master, 0) + + override def beforeAll(): Unit = { + super.beforeAll() + masterWebUI.bind() + } + + override def afterAll(): Unit = { + try { + masterWebUI.stop() + } finally { + super.afterAll() + } + } + + test("SPARK-50022: Fix 'MasterPage' to hide App UI links when UI is disabled") { + val url = s"http://${Utils.localHostNameForURI()}:${masterWebUI.boundPort}/" + val conn = sendHttpRequest(url, "GET") + assert(conn.getResponseCode === SC_OK) + val result = Source.fromInputStream(conn.getInputStream).mkString + assert(result.contains("WithUI")) + assert(result.contains(" WithoutUI\n")) + } +} diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 0dcc7c7f9b4cf..909d605442575 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -302,7 +302,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf)) assert(backend.taskResources.isEmpty) - val taskId = 1000000 + val taskId = 1000000L // We don't really verify the data, just pass it around. 
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19, @@ -339,14 +339,14 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite backend.self.send(LaunchTask(new SerializableBuffer(serializedTaskDescription))) eventually(timeout(10.seconds)) { assert(backend.taskResources.size == 1) - val resources = backend.taskResources(taskId) + val resources = backend.taskResources.get(taskId) assert(resources(GPU).addresses sameElements Array("0", "1")) } // Update the status of a running task shall not affect `taskResources` map. backend.statusUpdate(taskId, TaskState.RUNNING, data) assert(backend.taskResources.size == 1) - val resources = backend.taskResources(taskId) + val resources = backend.taskResources.get(taskId) assert(resources(GPU).addresses sameElements Array("0", "1")) // Update the status of a finished task shall remove the entry from `taskResources` map. diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 5e66ca962ea2c..c4ef45658ae9c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -201,10 +201,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with TimeLimits { test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) - .mapPartitions(itr => { Thread.sleep(20); itr }) + .mapPartitions(itr => { Thread.sleep(200); itr }) .countAsync() intercept[TimeoutException] { - ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) + ThreadUtils.awaitResult(f, Duration(2, "milliseconds")) } } diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala index e97d5c7883aa8..7149267583bc5 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala @@ -126,18 +126,34 @@ class ResourceProfileManagerSuite extends SparkFunSuite { val defaultProf = rpmanager.defaultResourceProfile assert(rpmanager.isSupported(defaultProf)) - // task resource profile. + // Standalone: supports task resource profile. val gpuTaskReq = new TaskResourceRequests().resource("gpu", 1) val taskProf = new TaskResourceProfile(gpuTaskReq.requests) assert(rpmanager.isSupported(taskProf)) + // Local: doesn't support task resource profile. conf.setMaster("local") rpmanager = new ResourceProfileManager(conf, listenerBus) val error = intercept[SparkException] { rpmanager.isSupported(taskProf) }.getMessage - assert(error === "TaskResourceProfiles are only supported for Standalone " + - "cluster for now when dynamic allocation is disabled.") + assert(error === "TaskResourceProfiles are only supported for Standalone, " + + "Yarn and Kubernetes cluster for now when dynamic allocation is disabled.") + + // Local cluster: supports task resource profile. + conf.setMaster("local-cluster[1, 1, 1024]") + rpmanager = new ResourceProfileManager(conf, listenerBus) + assert(rpmanager.isSupported(taskProf)) + + // Yarn: supports task resource profile. + conf.setMaster("yarn") + rpmanager = new ResourceProfileManager(conf, listenerBus) + assert(rpmanager.isSupported(taskProf)) + + // K8s: supports task resource profile. 
+ conf.setMaster("k8s://foo") + rpmanager = new ResourceProfileManager(conf, listenerBus) + assert(rpmanager.isSupported(taskProf)) } test("isSupported task resource profiles with dynamic allocation enabled") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 26cd5374fa09c..f370a8e02391e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -101,7 +101,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) + Thread.sleep(Random.nextInt(500)) context.barrier() val time1 = System.currentTimeMillis() // Sleep for a random time before global sync. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c7e4994e328f4..1818bf9b152d3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, Clock, LongAccumulator, SystemClock, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, Clock, LongAccumulator, SystemClock, ThreadUtils, Utils} class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -594,6 +594,42 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assertDataStructuresEmpty() } + // Note that this test is NOT perfectly reproducible when there is a deadlock as it uses + // Thread.sleep, but it should never fail / flake when there is no deadlock. + // If this test starts to flake, this shows that there is a deadlock! + test("No Deadlock between getCacheLocs and CoalescedRDD") { + val rdd = sc.parallelize(1 to 10, numSlices = 10) + val coalescedRDD = rdd.coalesce(2) + val executionContext = ThreadUtils.newDaemonFixedThreadPool( + nThreads = 2, "test-getCacheLocs") + // Used to only make progress on getCacheLocs after we acquired the lock to the RDD. + val rddLock = new java.util.concurrent.Semaphore(0) + val partitionsFuture = executionContext.submit(new Runnable { + override def run(): Unit = { + coalescedRDD.stateLock.synchronized { + rddLock.release(1) + // Try to access the partitions of the coalescedRDD. This will cause a call to + // getCacheLocs internally. + Thread.sleep(5000) + coalescedRDD.partitions + } + } + }) + val getCacheLocsFuture = executionContext.submit(new Runnable { + override def run(): Unit = { + rddLock.acquire() + // Access the cache locations. + // If the partition location cache is locked before the stateLock is locked, + // we'll run into a deadlock. 
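Editor's note: the deadlock test above comes down to two threads taking the RDD's stateLock and the scheduler's partition-location cache lock in opposite orders. A generic, runnable Scala sketch (not Spark code; all names here are illustrative) of the lock-ordering discipline that makes such a deadlock impossible:

```scala
// A generic illustration, not Spark's DAGScheduler: if every thread agrees to
// take `outer` before `inner`, the circular wait required for a deadlock
// cannot form, no matter how the threads interleave.
object LockOrderingSketch {
  private val outer = new Object // stands in for e.g. an RDD state lock
  private val inner = new Object // stands in for e.g. a location-cache lock

  private def doWork(name: String): Unit = outer.synchronized {
    inner.synchronized {
      println(s"$name acquired outer then inner")
    }
  }

  def main(args: Array[String]): Unit = {
    val threads = Seq("t-1", "t-2").map(n => new Thread(() => doWork(n)))
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```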
+ sc.dagScheduler.getCacheLocs(coalescedRDD) + } + }) + // If any of the futures throw a TimeOutException, this shows that there is a deadlock between + // getCacheLocs and accessing partitions of an RDD. + getCacheLocsFuture.get(120, TimeUnit.SECONDS) + partitionsFuture.get(120, TimeUnit.SECONDS) + } + test("All shuffle files on the storage endpoint should be cleaned up when it is lost") { conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") @@ -3041,6 +3077,27 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti (shuffleId1, shuffleId2) } + private def constructTwoIndeterminateStage(): (Int, Int) = { + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker, + indeterminate = true) + + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + // Finish the first shuffle map stage. + completeShuffleMapStageSuccessfully(0, 0, 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + (shuffleId1, shuffleId2) + } + test("SPARK-25341: abort stage while using old fetch protocol") { conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true") // Construct the scenario of indeterminate stage fetch failed. @@ -3099,6 +3156,92 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assertDataStructuresEmpty() } + test("SPARK-45182: Ignore task completion from old stage after retrying indeterminate stages") { + val (shuffleId1, shuffleId2) = constructTwoIndeterminateStage() + + // shuffleMapStage0 -> shuffleId1 -> shuffleMapStage1 -> shuffleId2 -> resultStage + val shuffleMapStage1 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage] + val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage] + + // Shuffle map stage 0 is done + assert(mapOutputTracker.findMissingPartitions(shuffleId1) == Some(Seq.empty)) + // Shuffle map stage 1 is still waiting for its 2 tasks to complete + assert(mapOutputTracker.findMissingPartitions(shuffleId2) == Some(Seq(0, 1))) + // The result stage is still waiting for its 2 tasks to complete + assert(resultStage.findMissingPartitions() == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // The first task of the shuffle map stage 1 fails with fetch failure + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"), + null)) + + // Both the stages should have been resubmitted + val newFailedStages = scheduler.failedStages.toSeq + assert(newFailedStages.map(_.id) == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // Since shuffleId1 is indeterminate, all tasks of shuffle map stage 0 should be ran + assert(taskSets(2).stageId == 0) + assert(taskSets(2).stageAttemptId == 1) + assert(taskSets(2).tasks.length == 2) + + // Complete the re-attempt of shuffle map stage 0 + completeShuffleMapStageSuccessfully(0, 1, 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + // Since shuffleId2 is indeterminate, all tasks of shuffle map stage 1 should be ran + assert(taskSets(3).stageId == 1) + 
assert(taskSets(3).stageAttemptId == 1) + assert(taskSets(3).tasks.length == 2) + + // The first task of the shuffle map stage 1 from 2nd attempt succeeds + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostB", + 2))) + + // The second task of the shuffle map stage 1 from 1st attempt succeeds + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostC", + 2))) + + // Above task completion should not mark the partition 1 complete from 2nd attempt + assert(!tasksMarkedAsCompleted.contains(taskSets(3).tasks(1))) + + // This task completion should get ignored and partition 1 should be missing + // for shuffle map stage 1 + assert(mapOutputTracker.findMissingPartitions(shuffleId2) == Some(Seq(1))) + + // The second task of the shuffle map stage 1 from 2nd attempt succeeds + runEvent(makeCompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostD", + 2))) + + // The shuffle map stage 1 should be done + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // The shuffle map outputs for shuffleId1 should be from latest attempt of shuffle map stage 1 + assert(mapOutputTracker.getMapLocation(shuffleMapStage1.shuffleDep, 0, 2) + === Seq("hostB", "hostD")) + + // Complete result stage + complete(taskSets(4), Seq((Success, 11), (Success, 12))) + + // Job successfully ended + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + test("SPARK-25341: continuous indeterminate stage roll back") { // shuffleMapRdd1/2/3 are all indeterminate. val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 45da750768fa9..7d063c3b3ac53 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} +import org.scalatest.time.{Seconds, Span} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite, TaskContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} /** * Integration tests for the OutputCommitCoordinator. 
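Editor's note: the OutputCommitCoordinator hunks below relax the previous behavior so that, once the authorized committer fails, a later task attempt can be granted the commit lock instead of the whole stage being failed. A toy Scala model of that rule (a sketch of the semantics the updated tests assert, not Spark's OutputCommitCoordinator; all names are illustrative):

```scala
// Toy model of commit authorization: one committer per partition at a time,
// and failure of the authorized attempt releases the lock for later attempts.
object CommitAuthorizationSketch {
  // partition -> attempt number currently authorized to commit, if any
  private var authorized = Map.empty[Int, Int]

  def canCommit(partition: Int, attempt: Int): Boolean = synchronized {
    authorized.get(partition) match {
      case None            => authorized += partition -> attempt; true // first asker wins
      case Some(`attempt`) => true                                     // holder may ask again
      case Some(_)         => false                                    // someone else holds the lock
    }
  }

  def authorizedAttemptFailed(partition: Int, attempt: Int): Unit = synchronized {
    if (authorized.get(partition).contains(attempt)) authorized -= partition
  }

  def main(args: Array[String]): Unit = {
    assert(canCommit(partition = 0, attempt = 1))  // attempt 1 becomes the committer
    assert(!canCommit(partition = 0, attempt = 2)) // attempt 2 is rejected meanwhile
    authorizedAttemptFailed(partition = 0, attempt = 1)
    assert(canCommit(partition = 0, attempt = 3))  // after the failure, a new attempt may commit
    assert(!canCommit(partition = 0, attempt = 4)) // still only one committer at a time
  }
}
```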
@@ -44,15 +45,13 @@ class OutputCommitCoordinatorIntegrationSuite sc = new SparkContext("local[2, 4]", "test", conf) } - test("SPARK-39195: exception thrown in OutputCommitter.commitTask()") { + test("exception thrown in OutputCommitter.commitTask()") { // Regression test for SPARK-10381 - val e = intercept[SparkException] { + failAfter(Span(60, Seconds)) { withTempDir { tempDir => sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") } - }.getCause.getMessage - assert(e.contains("failed; but task commit success, data duplication may happen.") && - e.contains("Intentional exception")) + } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 44dc9a5f97dab..d84892be14af5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -87,12 +87,11 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { outputCommitCoordinator = - spy[OutputCommitCoordinator]( - new OutputCommitCoordinator(conf, isDriver = true, Option(this))) + spy[OutputCommitCoordinator](new OutputCommitCoordinator(conf, isDriver = true)) // Use Mockito.spy() to maintain the default infrastructure everywhere else. // This mocking allows us to control the coordinator responses in test cases. SparkEnv.createDriverEnv(conf, isLocal, listenerBus, - SparkContext.numDriverCores(master), this, Some(outputCommitCoordinator)) + SparkContext.numDriverCores(master), Some(outputCommitCoordinator)) } } // Use Mockito.spy() to maintain the default infrastructure everywhere else @@ -190,9 +189,12 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) - // A new task should not be allowed to become stage failed because of potential data duplication - assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + // A new task should now be allowed to become the authorized committer + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { @@ -226,8 +228,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, - // then fail the 1st attempt and since stage failed because of potential data duplication, - // make sure fail the 4th attempt. + // then fail the 1st attempt and make sure the 4th one can commit again. 
stage += 1 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) @@ -236,9 +237,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, ExecutorLostFailure("0", exitCausedByApp = true, None)) - // A new task should not be allowed to become the authorized committer since stage failed - // because of potential data duplication - assert(!outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) } test("SPARK-24589: Make sure stage state is cleaned up") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 54a42c1a66184..a5c2cbf52aafd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -669,6 +669,16 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(invocationOrder === Seq("C", "B", "A", "D")) } + test("SPARK-46480: Add isFailed in TaskContext") { + val context = TaskContext.empty() + var isFailed = false + context.addTaskCompletionListener[Unit] { context => + isFailed = context.isFailed() + } + context.markTaskFailed(new RuntimeException()) + context.markTaskCompleted(None) + assert(isFailed) + } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 9e52b5e15143b..99402abb16cac 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -85,6 +85,7 @@ class SortShuffleWriterSuite shuffleHandle, mapId = 1, context, + context.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) writer.write(Iterator.empty) writer.stop(success = true) @@ -102,6 +103,7 @@ class SortShuffleWriterSuite shuffleHandle, mapId = 2, context, + context.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) writer.write(records.iterator) writer.stop(success = true) @@ -158,6 +160,7 @@ class SortShuffleWriterSuite shuffleHandle, mapId = 0, context, + context.taskMetrics().shuffleWriteMetrics, new LocalDiskShuffleExecutorComponents( conf, shuffleBlockResolver._blockManager, shuffleBlockResolver)) writer.write(records.iterator) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 3708f0aa67223..f133a38269d71 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -166,6 +166,20 @@ class BlockInfoManagerSuite extends SparkFunSuite { assert(blockInfoManager.get("block").get.readerCount === 1) } + test("lockNewBlockForWriting should not block when keepReadLock is false") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val lock1Future = Future { + withTaskId(1) { + blockInfoManager.lockNewBlockForWriting("block", newBlockInfo(), false) + } + } + + 
assert(!ThreadUtils.awaitResult(lock1Future, 1.seconds)) + assert(blockInfoManager.get("block").get.readerCount === 0) + } + test("read locks are reentrant") { withTaskId(1) { assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index d9d2e6102f120..2ba348222f7be 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.storage +import java.io.File import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Semaphore, TimeUnit} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import org.apache.commons.io.FileUtils import org.scalatest.concurrent.Eventually import org.apache.spark._ @@ -352,4 +354,78 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS import scala.language.reflectiveCalls assert(listener.removeReasonValidated) } + + test("SPARK-46957: Migrated shuffle files should be able to cleanup from executor") { + + val sparkTempDir = System.getProperty("java.io.tmpdir") + + def shuffleFiles: Seq[File] = { + FileUtils + .listFiles(new File(sparkTempDir), Array("data", "index"), true) + .asScala + .toSeq + } + + val existingShuffleFiles = shuffleFiles + + val conf = new SparkConf() + .setAppName("SPARK-46957") + .setMaster("local-cluster[2,1,1024]") + .set(config.DECOMMISSION_ENABLED, true) + .set(config.STORAGE_DECOMMISSION_ENABLED, true) + .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, true) + sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val shuffleBlockUpdates = new ArrayBuffer[BlockId]() + var isDecommissionedExecutorRemoved = false + val execToDecommission = sc.getExecutorIds().head + sc.addSparkListener(new SparkListener { + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + if (blockUpdated.blockUpdatedInfo.blockId.isShuffle) { + shuffleBlockUpdates += blockUpdated.blockUpdatedInfo.blockId + } + } + + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + assert(execToDecommission === executorRemoved.executorId) + isDecommissionedExecutorRemoved = true + } + }) + + // Run a job to create shuffle data + val result = sc.parallelize(1 to 1000, 10) + .map { i => (i % 2, i) } + .reduceByKey(_ + _).collect() + + assert(result.head === (0, 250500)) + assert(result.tail.head === (1, 250000)) + sc.schedulerBackend + .asInstanceOf[StandaloneSchedulerBackend] + .decommissionExecutor( + execToDecommission, + ExecutorDecommissionInfo("test", None), + adjustTargetNumExecutors = true + ) + + eventually(timeout(1.minute), interval(10.milliseconds)) { + assert(isDecommissionedExecutorRemoved) + // Ensure there are shuffle data have been migrated + assert(shuffleBlockUpdates.size >= 2) + } + + val shuffleId = shuffleBlockUpdates + .find(_.isInstanceOf[ShuffleIndexBlockId]) + .map(_.asInstanceOf[ShuffleIndexBlockId].shuffleId) + .get + + val newShuffleFiles = shuffleFiles.diff(existingShuffleFiles) + assert(newShuffleFiles.size >= shuffleBlockUpdates.size) + + // Remove the shuffle data + sc.shuffleDriverComponents.removeShuffle(shuffleId, true) + + eventually(timeout(1.minute), 
interval(10.milliseconds)) { + assert(newShuffleFiles.intersect(shuffleFiles).isEmpty) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 38a669bc85744..29526684c3e9f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -38,6 +38,8 @@ import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.shuffle.ExternalBlockStoreClient +import org.apache.spark.network.util.{MapConfigProvider, TransportConf} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} @@ -295,6 +297,40 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } + test("Test block location after replication with SHUFFLE_SERVICE_FETCH_RDD_ENABLED enabled") { + val newConf = conf.clone() + newConf.set(SHUFFLE_SERVICE_ENABLED, true) + newConf.set(SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() + val shuffleClient = Some(new ExternalBlockStoreClient( + new TransportConf("shuffle", MapConfigProvider.EMPTY), + null, false, 5000)) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager-2", + new BlockManagerMasterEndpoint(rpcEnv, true, newConf, + new LiveListenerBus(newConf), shuffleClient, blockManagerInfo, mapOutputTracker, + sc.env.shuffleManager, isDriver = true)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat-2", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), newConf, true) + + val shuffleServicePort = newConf.get(SHUFFLE_SERVICE_PORT) + val store1 = makeBlockManager(10000, "host-1") + val store2 = makeBlockManager(10000, "host-2") + assert(master.getPeers(store1.blockManagerId).toSet === Set(store2.blockManagerId)) + + val blockId = RDDBlockId(1, 2) + val message = new Array[Byte](1000) + + // if SHUFFLE_SERVICE_FETCH_RDD_ENABLED is enabled, then shuffle port should be present. + store1.putSingle(blockId, message, StorageLevel.DISK_ONLY) + assert(master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + + // after block is removed, shuffle port should be removed. 
+ store1.removeBlock(blockId, true) + assert(!master.getLocations(blockId).contains( + BlockManagerId("host-1", "localhost", shuffleServicePort, None))) + } + test("block replication - addition and deletion of block managers") { val blockSize = 1000 val storeSize = 10000 diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd66dc2c5fb0..728e3a252b7a1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,7 +33,7 @@ import scala.reflect.classTag import com.esotericsoftware.kryo.KryoException import org.apache.commons.lang3.RandomUtils import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} -import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} +import org.mockito.Mockito.{atLeastOnce, doAnswer, mock, never, spy, times, verify, when} import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ @@ -666,7 +666,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe removedFromMemory: Boolean, removedFromDisk: Boolean): Unit = { def assertSizeReported(captor: ArgumentCaptor[Long], expectRemoved: Boolean): Unit = { - assert(captor.getAllValues().size() === 1) + assert(captor.getAllValues().size() >= 1) if (expectRemoved) { assert(captor.getValue() > 0) } else { @@ -676,15 +676,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe val memSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] val diskSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] - verify(master).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), memSizeCaptor.capture(), diskSizeCaptor.capture()) + val storageLevelCaptor = + ArgumentCaptor.forClass(classOf[StorageLevel]).asInstanceOf[ArgumentCaptor[StorageLevel]] + verify(master, atLeastOnce()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), + storageLevelCaptor.capture(), memSizeCaptor.capture(), diskSizeCaptor.capture()) assertSizeReported(memSizeCaptor, removedFromMemory) assertSizeReported(diskSizeCaptor, removedFromDisk) + assert(storageLevelCaptor.getValue.replication == 0) } private def assertUpdateBlockInfoNotReported(store: BlockManager, blockId: BlockId): Unit = { verify(master, never()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), - mc.eq(StorageLevel.NONE), mc.anyInt(), mc.anyInt()) + mc.any[StorageLevel](), mc.anyInt(), mc.anyInt()) } test("reregistration on heart beat") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 70a57eed07acd..4352436c872fe 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -16,11 +16,14 @@ */ package org.apache.spark.storage -import java.io.File +import java.io.{File, InputStream, OutputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import 
org.apache.spark.serializer.{DeserializationStream, JavaSerializer, SerializationStream, Serializer, SerializerInstance, SerializerManager} import org.apache.spark.util.Utils class DiskBlockObjectWriterSuite extends SparkFunSuite { @@ -43,10 +46,14 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite { private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = { val file = new File(tempDir, "somefile") val conf = new SparkConf() - val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) + val serializerManager = new CustomSerializerManager(new JavaSerializer(conf), conf, None) val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true, + file, + serializerManager, + new CustomJavaSerializer(new SparkConf()).newInstance(), + 1024, + true, writeMetrics) (writer, file, writeMetrics) } @@ -196,9 +203,76 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite { for (i <- 1 to 500) { writer.write(i, i) } + + val bs = writer.getSerializerWrappedStream.asInstanceOf[OutputStreamWithCloseDetecting] + val objOut = writer.getSerializationStream.asInstanceOf[SerializationStreamWithCloseDetecting] + writer.closeAndDelete() assert(!file.exists()) assert(writeMetrics.bytesWritten == 0) assert(writeMetrics.recordsWritten == 0) + assert(bs.isClosed) + assert(objOut.isClosed) + } +} + +trait CloseDetecting { + var isClosed = false +} + +class OutputStreamWithCloseDetecting(outputStream: OutputStream) + extends OutputStream + with CloseDetecting { + override def write(b: Int): Unit = outputStream.write(b) + + override def close(): Unit = { + isClosed = true + outputStream.close() + } +} + +class CustomSerializerManager( + defaultSerializer: Serializer, + conf: SparkConf, + encryptionKey: Option[Array[Byte]]) + extends SerializerManager(defaultSerializer, conf, encryptionKey) { + override def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { + new OutputStreamWithCloseDetecting(wrapForCompression(blockId, wrapForEncryption(s))) + } +} + +class CustomJavaSerializer(conf: SparkConf) extends JavaSerializer(conf) { + + override def newInstance(): SerializerInstance = { + new CustomJavaSerializerInstance(super.newInstance()) } } + +class SerializationStreamWithCloseDetecting(serializationStream: SerializationStream) + extends SerializationStream with CloseDetecting { + + override def close(): Unit = { + isClosed = true + serializationStream.close() + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = + serializationStream.writeObject(t) + + override def flush(): Unit = serializationStream.flush() +} + +class CustomJavaSerializerInstance(instance: SerializerInstance) extends SerializerInstance { + override def serializeStream(s: OutputStream): SerializationStream = + new SerializationStreamWithCloseDetecting(instance.serializeStream(s)) + + override def serialize[T: ClassTag](t: T): ByteBuffer = instance.serialize(t) + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = instance.deserialize(bytes) + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + instance.deserialize(bytes, loader) + + override def deserializeStream(s: InputStream): DeserializationStream = + instance.deserializeStream(s) +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index af37a72c9e3f8..a9902cb4ccb4c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -182,6 +182,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], taskContext: Option[TaskContext] = None, streamWrapperLimitSize: Option[Long] = None, + corruptAtAvailableReset: Boolean = false, blockManager: Option[BlockManager] = None, maxBytesInFlight: Long = Long.MaxValue, maxReqsInFlight: Int = Int.MaxValue, @@ -201,7 +202,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager.getOrElse(createMockBlockManager()), mapOutputTracker, blocksByAddress.iterator, - (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), + (_, in) => { + val limited = streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in) + if (corruptAtAvailableReset) { + new CorruptAvailableResetStream(limited) + } else { + limited + } + }, maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, @@ -712,6 +720,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT corruptBuffer } + private class CorruptAvailableResetStream(in: InputStream) extends InputStream { + override def read(): Int = in.read() + + override def read(dest: Array[Byte], off: Int, len: Int): Int = in.read(dest, off, len) + + override def available(): Int = throw new IOException("corrupt at available") + + override def reset(): Unit = throw new IOException("corrupt at reset") + } + private class CorruptStream(corruptAt: Long = 0L) extends InputStream { var pos = 0 var closed = false @@ -1879,4 +1897,75 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) verifyLocalBlocksFromFallback(iterator) } + + test("SPARK-45678: retry corrupt blocks on available() and reset()") { + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer() + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, createMockManagedBuffer()) + sem.release() + } + } + + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100), + detectCorruptUseExtraMemory = false, // Don't use `ChunkedByteBufferInputStream`. 
+ corruptAtAvailableReset = true, + checksumEnabled = false + ) + + sem.acquire() + + val (id1, stream) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + + val err1 = intercept[FetchFailedException] { + stream.available() + } + + assert(err1.getMessage.contains("corrupt at available")) + + val err2 = intercept[FetchFailedException] { + stream.reset() + } + + assert(err2.getMessage.contains("corrupt at reset")) + } + + test("SPARK-43242: Fix throw 'Unexpected type of BlockId' in shuffle corruption diagnose") { + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockBatchId(0, 0, 0, 3) -> createMockManagedBuffer()) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess(ShuffleBlockBatchId(0, 0, 0, 3).toString, mockCorruptBuffer()) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100) + ) + intercept[FetchFailedException](iterator.next()) + verify(transfer, times(2)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + assert(logAppender.loggingEvents.count( + _.getMessage.getFormattedMessage.contains("Start corruption diagnosis")) === 1) + assert(logAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains("shuffle_0_0_0_3 is corrupted " + + "but corruption diagnosis is skipped due to lack of " + + "shuffle checksum support for ShuffleBlockBatchId"))) + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index aecd25f6c8dea..5586badd668dd 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -61,6 +61,20 @@ class UIUtilsSuite extends SparkFunSuite { errorMsg = "Base URL should be prepended to html links", plainText = false ) + + verify( + """""", + {""""""}, + "Non href attributes should make the description be treated as a string instead of HTML", + plainText = false + ) + + verify( + """""", + {""""""}, + "Non href attributes should make the description be treated as a string instead of HTML", + plainText = false + ) } test("makeDescription(plainText = true)") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7923e81949db6..1a7bfc64c23c7 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1093,6 +1093,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { // Set some secret keys val secretKeys = Seq( "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD", + "spark.hadoop.fs.s3.awsAccessKeyId", "spark.hadoop.fs.s3a.access.key", "spark.my.password", "spark.my.sECreT") diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 1af99e9017c9c..f7b026ab565f0 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -249,4 +249,34 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { map(null) = null assert(map.get(null) === 
Some(null)) } + + test("SPARK-45599: 0.0 and -0.0 should count distinctly; NaNs should count together") { + // Exactly these elements provided in roughly this order trigger a condition where lookups of + // 0.0 and -0.0 in the bitset happen to collide, causing their counts to be merged incorrectly + // and inconsistently if `==` is used to check for key equality. + val spark45599Repro = Seq( + Double.NaN, + 2.0, + 168.0, + Double.NaN, + Double.NaN, + -0.0, + 153.0, + 0.0 + ) + + val map1 = new OpenHashMap[Double, Int]() + spark45599Repro.foreach(map1.changeValue(_, 1, {_ + 1})) + assert(map1(0.0) == 1) + assert(map1(-0.0) == 1) + assert(map1(Double.NaN) == 3) + + val map2 = new OpenHashMap[Double, Int]() + // Simply changing the order in which the elements are added to the map should not change the + // counts for 0.0 and -0.0. + spark45599Repro.reverse.foreach(map2.changeValue(_, 1, {_ + 1})) + assert(map2(0.0) == 1) + assert(map2(-0.0) == 1) + assert(map2(Double.NaN) == 3) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 89a308556d5df..0bc8aa067f57a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -269,4 +269,43 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(pos1 == pos2) } } + + test("SPARK-45599: 0.0 and -0.0 are equal but not the same") { + // Therefore, 0.0 and -0.0 should get separate entries in the hash set. + // + // Exactly these elements provided in roughly this order will trigger the following scenario: + // When probing the bitset in `getPos(-0.0)`, the loop will happen upon the entry for 0.0. + // In the old logic pre-SPARK-45599, the loop will find that the bit is set and, because + // -0.0 == 0.0, it will think that's the position of -0.0. But in reality this is the position + // of 0.0. So -0.0 and 0.0 will be stored at different positions, but `getPos()` will return + // the same position for them. This can cause users of OpenHashSet, like OpenHashMap, to + // return the wrong value for a key based on whether or not this bitset lookup collision + // happens. + val spark45599Repro = Seq( + Double.NaN, + 2.0, + 168.0, + Double.NaN, + Double.NaN, + -0.0, + 153.0, + 0.0 + ) + val set = new OpenHashSet[Double]() + spark45599Repro.foreach(set.add) + assert(set.size == 6) + val zeroPos = set.getPos(0.0) + val negZeroPos = set.getPos(-0.0) + assert(zeroPos != negZeroPos) + } + + test("SPARK-45599: NaN and NaN are the same but not equal") { + // Any mathematical comparison to NaN will return false, but when we place it in + // a hash set we want the lookup to work like a "normal" value. 
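Editor's note: the SPARK-45599 tests above rest on a subtlety of double comparison on the JVM. A standalone Scala sketch (plain JVM semantics, no Spark classes) showing why `==` alone is the wrong key-equality test for a Double-keyed hash structure:

```scala
// 0.0 and -0.0 compare equal with == but have different bit patterns, while
// NaN compares unequal to itself but is `equals`-equal to itself when boxed.
// A Double-keyed hash container therefore needs bit/equals-based key checks.
object DoubleKeySemantics {
  def main(args: Array[String]): Unit = {
    assert(0.0 == -0.0)                                             // numerically equal
    assert(java.lang.Double.doubleToRawLongBits(0.0) !=
      java.lang.Double.doubleToRawLongBits(-0.0))                   // but distinct bit patterns
    assert(Double.NaN != Double.NaN)                                // IEEE-754 comparison
    assert(java.lang.Double.valueOf(Double.NaN).equals(Double.NaN)) // boxed equality holds
    println("0.0 / -0.0 are distinct keys; NaN matches itself as a key")
  }
}
```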
+ val set = new OpenHashSet[Double]() + set.add(Double.NaN) + set.add(Double.NaN) + assert(set.contains(Double.NaN)) + assert(set.size == 1) + } } diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 16e0e3e30c9e5..6bf840cee2831 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -145,3 +145,4 @@ empty.proto .*\.proto.bin LimitedInputStream.java TimSort.java +.*\.har diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index 3737382eb86e2..792a9aa4e9793 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -81,7 +81,7 @@ if (!(Test-Path $tools)) { # ========================== Maven # Push-Location $tools # -# $mavenVer = "3.8.8" +# $mavenVer = "3.9.6" # Start-FileDownload "https://archive.apache.org/dist/maven/maven-3/$mavenVer/binaries/apache-maven-$mavenVer-bin.zip" "maven.zip" # # # extract diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh index 88398bc14dd02..ea3105b3d0a70 100755 --- a/dev/create-release/do-release-docker.sh +++ b/dev/create-release/do-release-docker.sh @@ -84,8 +84,8 @@ if [ ! -z "$RELEASE_STEP" ] && [ "$RELEASE_STEP" = "finalize" ]; then error "Exiting." fi - if [ -z "$PYPI_PASSWORD" ]; then - stty -echo && printf "PyPi password: " && read PYPI_PASSWORD && printf '\n' && stty echo + if [ -z "$PYPI_API_TOKEN" ]; then + stty -echo && printf "PyPi API token: " && read PYPI_API_TOKEN && printf '\n' && stty echo fi fi @@ -142,7 +142,7 @@ GIT_NAME=$GIT_NAME GIT_EMAIL=$GIT_EMAIL GPG_KEY=$GPG_KEY ASF_PASSWORD=$ASF_PASSWORD -PYPI_PASSWORD=$PYPI_PASSWORD +PYPI_API_TOKEN=$PYPI_API_TOKEN GPG_PASSPHRASE=$GPG_PASSPHRASE RELEASE_STEP=$RELEASE_STEP USER=$USER diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e0588ae934cd2..99841916cf293 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -95,8 +95,8 @@ init_java init_maven_sbt if [[ "$1" == "finalize" ]]; then - if [[ -z "$PYPI_PASSWORD" ]]; then - error 'The environment variable PYPI_PASSWORD is not set. Exiting.' + if [[ -z "$PYPI_API_TOKEN" ]]; then + error 'The environment variable PYPI_API_TOKEN is not set. Exiting.' fi git config --global user.name "$GIT_NAME" @@ -104,22 +104,27 @@ if [[ "$1" == "finalize" ]]; then # Create the git tag for the new release echo "Creating the git tag for the new release" - rm -rf spark - git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b master - cd spark - git tag "v$RELEASE_VERSION" "$RELEASE_TAG" - git push origin "v$RELEASE_VERSION" - cd .. - rm -rf spark - echo "git tag v$RELEASE_VERSION created" + if check_for_tag "v$RELEASE_VERSION"; then + echo "v$RELEASE_VERSION already exists. Skip creating it." + else + rm -rf spark + git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b master + cd spark + git tag "v$RELEASE_VERSION" "$RELEASE_TAG" + git push origin "v$RELEASE_VERSION" + cd .. + rm -rf spark + echo "git tag v$RELEASE_VERSION created" + fi # download PySpark binary from the dev directory and upload to PyPi. 
echo "Uploading PySpark to PyPi" svn co --depth=empty "$RELEASE_STAGING_LOCATION/$RELEASE_TAG-bin" svn-spark cd svn-spark - svn update "pyspark-$RELEASE_VERSION.tar.gz" - svn update "pyspark-$RELEASE_VERSION.tar.gz.asc" - TWINE_USERNAME=spark-upload TWINE_PASSWORD="$PYPI_PASSWORD" twine upload \ + PYSPARK_VERSION=`echo "$RELEASE_VERSION" | sed -e "s/-/./" -e "s/preview/dev/"` + svn update "pyspark-$PYSPARK_VERSION.tar.gz" + svn update "pyspark-$PYSPARK_VERSION.tar.gz.asc" + twine upload -u __token__ -p $PYPI_API_TOKEN \ --repository-url https://upload.pypi.org/legacy/ \ "pyspark-$RELEASE_VERSION.tar.gz" \ "pyspark-$RELEASE_VERSION.tar.gz.asc" diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 85155b67bd5a3..789915d018def 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -42,7 +42,7 @@ ARG APT_INSTALL="apt-get install --no-install-recommends -y" # We should use the latest Sphinx version once this is fixed. # TODO(SPARK-35375): Jinja2 3.0.0+ causes error when building with Sphinx. # See also https://issues.apache.org/jira/browse/SPARK-35375. -ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.20.3 pydata_sphinx_theme==0.8.0 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.5.3 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.56.0 protobuf==4.21.6 grpcio-status==1.56.0 googleapis-common-protos==1.56.4" +ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.20.3 pydata_sphinx_theme==0.8.0 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 sphinx-copybutton==0.5.2 pandas==2.0.3 pyarrow==4.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.56.0 protobuf==4.21.6 grpcio-status==1.56.0 googleapis-common-protos==1.56.4" ARG GEM_PKGS="bundler:2.3.8" # Install extra needed repos and refresh. 
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 1d02f8dba567e..a9d63c1ad0f99 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -4,7 +4,7 @@ JTransforms/3.1//JTransforms-3.1.jar RoaringBitmap/0.9.45//RoaringBitmap-0.9.45.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar -aircompressor/0.25//aircompressor-0.25.jar +aircompressor/0.27//aircompressor-0.27.jar algebra_2.12/2.0.1//algebra_2.12-2.0.1.jar aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar @@ -36,14 +36,14 @@ cats-kernel_2.12/2.1.1//cats-kernel_2.12-2.1.1.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.12/0.10.0//chill_2.12-0.10.0.jar commons-cli/1.5.0//commons-cli-1.5.0.jar -commons-codec/1.16.0//commons-codec-1.16.0.jar +commons-codec/1.16.1//commons-codec-1.16.1.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-collections4/4.4//commons-collections4-4.4.jar commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.23.0//commons-compress-1.23.0.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.13.0//commons-io-2.13.0.jar +commons-io/2.16.1//commons-io-2.16.1.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.12.0//commons-lang3-3.12.0.jar commons-logging/1.1.3//commons-logging-1.1.3.jar @@ -130,8 +130,8 @@ jersey-container-servlet/2.40//jersey-container-servlet-2.40.jar jersey-hk2/2.40//jersey-hk2-2.40.jar jersey-server/2.40//jersey-server-2.40.jar jettison/1.1//jettison-1.1.jar -jetty-util-ajax/9.4.52.v20230823//jetty-util-ajax-9.4.52.v20230823.jar -jetty-util/9.4.52.v20230823//jetty-util-9.4.52.v20230823.jar +jetty-util-ajax/9.4.54.v20240208//jetty-util-ajax-9.4.54.v20240208.jar +jetty-util/9.4.54.v20240208//jetty-util-9.4.54.v20240208.jar jline/2.14.6//jline-2.14.6.jar joda-time/2.12.5//joda-time-2.12.5.jar jodd-core/3.5.2//jodd-core-3.5.2.jar @@ -207,14 +207,14 @@ netty-transport-native-unix-common/4.1.96.Final//netty-transport-native-unix-com netty-transport/4.1.96.Final//netty-transport-4.1.96.Final.jar objenesis/3.3//objenesis-3.3.jar okhttp/3.12.12//okhttp-3.12.12.jar -okio/1.15.0//okio-1.15.0.jar +okio/1.17.6//okio-1.17.6.jar opencsv/2.3//opencsv-2.3.jar opentracing-api/0.33.0//opentracing-api-0.33.0.jar opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar opentracing-util/0.33.0//opentracing-util-0.33.0.jar -orc-core/1.9.1/shaded-protobuf/orc-core-1.9.1-shaded-protobuf.jar -orc-mapreduce/1.9.1/shaded-protobuf/orc-mapreduce-1.9.1-shaded-protobuf.jar -orc-shims/1.9.1//orc-shims-1.9.1.jar +orc-core/1.9.4/shaded-protobuf/orc-core-1.9.4-shaded-protobuf.jar +orc-mapreduce/1.9.4/shaded-protobuf/orc-mapreduce-1.9.4-shaded-protobuf.jar +orc-shims/1.9.4//orc-shims-1.9.4.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar @@ -238,7 +238,7 @@ shims/0.9.45//shims-0.9.45.jar slf4j-api/2.0.7//slf4j-api-2.0.7.jar snakeyaml-engine/2.6//snakeyaml-engine-2.6.jar snakeyaml/2.0//snakeyaml-2.0.jar -snappy-java/1.1.10.3//snappy-java-1.1.10.3.jar +snappy-java/1.1.10.5//snappy-java-1.1.10.5.jar spire-macros_2.12/0.17.0//spire-macros_2.12-0.17.0.jar spire-platform_2.12/0.17.0//spire-platform_2.12-0.17.0.jar spire-util_2.12/0.17.0//spire-util_2.12-0.17.0.jar diff --git a/dev/free_disk_space_container b/dev/free_disk_space_container new file mode 100755 index 0000000000000..cc3b74643e4fa --- /dev/null 
+++ b/dev/free_disk_space_container @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +echo "==================================" +echo "Free up disk space on CI system" +echo "==================================" + +echo "Listing 100 largest packages" +dpkg-query -Wf '${Installed-Size}\t${Package}\n' | sort -n | tail -n 100 +df -h + +echo "Removing large packages" +rm -rf /__t/CodeQL +rm -rf /__t/go +rm -rf /__t/node + +df -h diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index af8e1a980f93c..f0b88666c040d 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -65,10 +65,10 @@ RUN Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='ht ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" RUN pypy3 -m pip install numpy 'pandas<=2.0.3' scipy coverage matplotlib -RUN python3.9 -m pip install numpy pyarrow 'pandas<=2.0.3' scipy unittest-xml-reporting plotly>=4.8 'mlflow>=2.3.1' coverage matplotlib openpyxl 'memory-profiler==0.60.0' 'scikit-learn==1.1.*' +RUN python3.9 -m pip install 'numpy==1.25.1' 'pyarrow==12.0.1' 'pandas<=2.0.3' scipy unittest-xml-reporting plotly>=4.8 'mlflow>=2.3.1' coverage 'matplotlib==3.7.2' openpyxl 'memory-profiler==0.60.0' 'scikit-learn==1.1.*' # Add Python deps for Spark Connect. -RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-status +RUN python3.9 -m pip install 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3' 'googleapis-common-protos==1.56.4' # Add torch as a testing dependency for TorchDistributor -RUN python3.9 -m pip install torch torchvision torcheval +RUN python3.9 -m pip install 'torch==2.0.1' 'torchvision==0.15.2' torcheval diff --git a/dev/is-changed.py b/dev/is-changed.py index 85f0d3cda6df4..1962e244d5dd7 100755 --- a/dev/is-changed.py +++ b/dev/is-changed.py @@ -17,6 +17,8 @@ # limitations under the License. # +import warnings +import traceback import os import sys from argparse import ArgumentParser @@ -82,4 +84,8 @@ def main(): if __name__ == "__main__": - main() + try: + main() + except Exception: + warnings.warn(f"Ignored exception:\n\n{traceback.format_exc()}") + print("true") diff --git a/dev/lint-python b/dev/lint-python index d040493c86c42..7ccd32451acc8 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -118,6 +118,7 @@ function mypy_annotation_test { echo "starting mypy annotations test..." MYPY_REPORT=$( ($MYPY_BUILD \ + --python-executable $PYTHON_EXECUTABLE \ --namespace-packages \ --config-file python/mypy.ini \ --cache-dir /tmp/.mypy_cache/ \ @@ -177,6 +178,7 @@ function mypy_examples_test { echo "starting mypy examples test..." 
MYPY_REPORT=$( (MYPYPATH=python $MYPY_BUILD \ + --python-executable $PYTHON_EXECUTABLE \ --namespace-packages \ --config-file python/mypy.ini \ --exclude "mllib/*" \ diff --git a/dev/requirements.txt b/dev/requirements.txt index 38a9b2447108c..0749af75aa4be 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -3,7 +3,7 @@ py4j # PySpark dependencies (optional) numpy -pyarrow +pyarrow<13.0.0 pandas scipy plotly @@ -37,6 +37,7 @@ numpydoc jinja2<3.0.0 sphinx<3.1.0 sphinx-plotly-directive +sphinx-copybutton<0.5.3 docutils<0.18.0 # See SPARK-38279. markupsafe==2.0.1 diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 33d253a47ea07..d29fc8726018d 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -168,6 +168,15 @@ def __hash__(self): ], ) +sketch = Module( + name="sketch", + dependencies=[tags], + source_file_regexes=[ + "common/sketch/", + ], + sbt_test_goals=["sketch/test"], +) + core = Module( name="core", dependencies=[kvstore, network_common, network_shuffle, unsafe, launcher], @@ -181,7 +190,7 @@ def __hash__(self): catalyst = Module( name="catalyst", - dependencies=[tags, core], + dependencies=[tags, sketch, core], source_file_regexes=[ "sql/catalyst/", ], @@ -295,15 +304,6 @@ def __hash__(self): ], ) -sketch = Module( - name="sketch", - dependencies=[tags], - source_file_regexes=[ - "common/sketch/", - ], - sbt_test_goals=["sketch/test"], -) - graphx = Module( name="graphx", dependencies=[tags, core], diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index d7967ac3afa90..36cc7a4f994dc 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -140,4 +140,8 @@ for HADOOP_HIVE_PROFILE in "${HADOOP_HIVE_PROFILES[@]}"; do fi done +if [[ -d "$FWDIR/dev/pr-deps" ]]; then + rm -rf "$FWDIR/dev/pr-deps" +fi + exit 0 diff --git a/docs/_config.yml b/docs/_config.yml index 19cadd69e61ba..3dea0c82204bd 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -19,8 +19,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 3.5.0 -SPARK_VERSION_SHORT: 3.5.0 +SPARK_VERSION: 3.5.4-SNAPSHOT +SPARK_VERSION_SHORT: 3.5.4 SCALA_BINARY_VERSION: "2.12" SCALA_VERSION: "2.12.18" MESOS_VERSION: 1.0.0 @@ -40,9 +40,11 @@ DOCSEARCH_SCRIPT: | inputSelector: '#docsearch-input', enhancedSearchInput: true, algoliaOptions: { - 'facetFilters': ["version:3.5.0"] + 'facetFilters': ["version:3.5.4"] }, debug: false // Set debug to true if you want to inspect the dropdown }); permalink: 404.html + +exclude: ['README.md'] diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 9b7c469246165..5116472eaa769 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -32,6 +32,25 @@ + {% production %} + + + + {% endproduction %} + diff --git a/docs/sql-data-sources-load-save-functions.md b/docs/sql-data-sources-load-save-functions.md index 9d0a3f9c72b9a..31f6d944bc972 100644 --- a/docs/sql-data-sources-load-save-functions.md +++ b/docs/sql-data-sources-load-save-functions.md @@ -218,7 +218,7 @@ present. It is important to realize that these save modes do not utilize any loc atomic. Additionally, when performing an `Overwrite`, the data will be deleted before writing out the new data. - +
    diff --git a/docs/sql-data-sources-orc.md b/docs/sql-data-sources-orc.md index 4e492598f595d..561f601aa4e56 100644 --- a/docs/sql-data-sources-orc.md +++ b/docs/sql-data-sources-orc.md @@ -129,7 +129,7 @@ When reading from Hive metastore ORC tables and inserting to Hive metastore ORC ### Configuration -
[rendering artifact: the HTML <table> markup in this save-modes hunk was stripped; only the row text survives: "Scala/Java | Any Language | Meaning" and "SaveMode.ErrorIfExists (default)"]
    @@ -230,7 +230,7 @@ Data source options of ORC can be set via: * `DataStreamWriter` * `OPTIONS` clause at [CREATE TABLE USING DATA_SOURCE](sql-ref-syntax-ddl-create-table-datasource.html) -
[rendering artifact: HTML <table> markup stripped from this hunk; surviving row text: "Property Name | Default | Meaning | Since Version" and "spark.sql.orc.impl"]
    diff --git a/docs/sql-data-sources-parquet.md b/docs/sql-data-sources-parquet.md index 925e47504e5ef..707871e798026 100644 --- a/docs/sql-data-sources-parquet.md +++ b/docs/sql-data-sources-parquet.md @@ -386,7 +386,7 @@ Data source options of Parquet can be set via: * `DataStreamWriter` * `OPTIONS` clause at [CREATE TABLE USING DATA_SOURCE](sql-ref-syntax-ddl-create-table-datasource.html) -
[rendering artifact: HTML <table> markup stripped from this hunk; surviving row text: "Property Name | Default | Meaning | Scope" and "mergeSchema"]
    @@ -434,7 +434,7 @@ Other generic options can be found in +
[rendering artifact: HTML <table> markup stripped from this hunk; surviving row text: "Property Name | Default | Meaning | Scope" and "datetimeRebaseMode"]
    @@ -616,14 +616,15 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession - + diff --git a/docs/sql-data-sources-protobuf.md b/docs/sql-data-sources-protobuf.md index f92a8f20b3570..c8ee139e344fe 100644 --- a/docs/sql-data-sources-protobuf.md +++ b/docs/sql-data-sources-protobuf.md @@ -18,7 +18,10 @@ license: | limitations under the License. --- -Since Spark 3.4.0 release, [Spark SQL](https://spark.apache.org/docs/latest/sql-programming-guide.html) provides built-in support for reading and writing protobuf data. +* This will become a table of contents (this text will be scraped). +{:toc} + +Since Spark 3.4.0 release, [Spark SQL](sql-programming-guide.html) provides built-in support for reading and writing protobuf data. ## Deploying The `spark-protobuf` module is external and not included in `spark-submit` or `spark-shell` by default. @@ -46,45 +49,53 @@ Kafka key-value record will be augmented with some metadata, such as the ingesti Spark SQL schema is generated based on the protobuf descriptor file or protobuf class passed to `from_protobuf` and `to_protobuf`. The specified protobuf class or protobuf descriptor file must match the data, otherwise, the behavior is undefined: it may fail or return arbitrary results. -### Python +
[rendering artifact: added HTML wrapper lines (the codetabs <div>/comment markup) were stripped from this hunk]
    +This div is only used to make markdown editor/viewer happy and does not display on web + ```python +
    + +{% highlight python %} + from pyspark.sql.protobuf.functions import from_protobuf, to_protobuf -# `from_protobuf` and `to_protobuf` provides two schema choices. Via Protobuf descriptor file, +# from_protobuf and to_protobuf provide two schema choices. Via Protobuf descriptor file, # or via shaded Java class. # give input .proto protobuf schema -# syntax = "proto3" +# syntax = "proto3" # message AppEvent { -# string name = 1; -# int64 id = 2; -# string context = 3; +# string name = 1; +# int64 id = 2; +# string context = 3; # } - -df = spark\ -.readStream\ -.format("kafka")\ -.option("kafka.bootstrap.servers", "host1:port1,host2:port2")\ -.option("subscribe", "topic1")\ -.load() +df = spark + .readStream + .format("kafka")\ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() # 1. Decode the Protobuf data of schema `AppEvent` into a struct; # 2. Filter by column `name`; # 3. Encode the column `event` in Protobuf format. # The Protobuf protoc command can be used to generate a protobuf descriptor file for give .proto file. -output = df\ -.select(from_protobuf("value", "AppEvent", descriptorFilePath).alias("event"))\ -.where('event.name == "alice"')\ -.select(to_protobuf("event", "AppEvent", descriptorFilePath).alias("event")) +output = df + .select(from_protobuf("value", "AppEvent", descriptorFilePath).alias("event")) + .where('event.name == "alice"') + .select(to_protobuf("event", "AppEvent", descriptorFilePath).alias("event")) # Alternatively, you can decode and encode the SQL columns into protobuf format using protobuf # class name. The specified Protobuf class must match the data, otherwise the behavior is undefined: # it may fail or return arbitrary result. To avoid conflicts, the jar file containing the # 'com.google.protobuf.*' classes should be shaded. An example of shading can be found at # https://github.com/rangadi/shaded-protobuf-classes. - -output = df\ -.select(from_protobuf("value", "org.sparkproject.spark_protobuf.protobuf.AppEvent").alias("event"))\ -.where('event.name == "alice"') +output = df + .select(from_protobuf("value", "org.sparkproject.spark_protobuf.protobuf.AppEvent").alias("event")) + .where('event.name == "alice"') output.printSchema() # root @@ -94,52 +105,66 @@ output.printSchema() # | |-- context: string (nullable = true) output = output -.select(to_protobuf("event", "org.sparkproject.spark_protobuf.protobuf.AppEvent").alias("event")) - -query = output\ -.writeStream\ -.format("kafka")\ -.option("kafka.bootstrap.servers", "host1:port1,host2:port2")\ -.option("topic", "topic2")\ -.start() + .select(to_protobuf("event", "org.sparkproject.spark_protobuf.protobuf.AppEvent").alias("event")) + +query = output + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2")\ + .option("topic", "topic2") + .start() + +{% endhighlight %} + +
    ``` +
    + +
    + +
    + +
    +This div is only used to make markdown editor/viewer happy and does not display on web -### Scala ```scala +
    + +{% highlight scala %} import org.apache.spark.sql.protobuf.functions._ -// `from_protobuf` and `to_protobuf` provides two schema choices. Via Protobuf descriptor file, +// `from_protobuf` and `to_protobuf` provides two schema choices. Via the protobuf descriptor file, // or via shaded Java class. // give input .proto protobuf schema -// syntax = "proto3" +// syntax = "proto3" // message AppEvent { -// string name = 1; -// int64 id = 2; -// string context = 3; +// string name = 1; +// int64 id = 2; +// string context = 3; // } val df = spark -.readStream -.format("kafka") -.option("kafka.bootstrap.servers", "host1:port1,host2:port2") -.option("subscribe", "topic1") -.load() + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() // 1. Decode the Protobuf data of schema `AppEvent` into a struct; // 2. Filter by column `name`; // 3. Encode the column `event` in Protobuf format. // The Protobuf protoc command can be used to generate a protobuf descriptor file for give .proto file. val output = df -.select(from_protobuf($"value", "AppEvent", descriptorFilePath) as $"event") -.where("event.name == \"alice\"") -.select(to_protobuf($"user", "AppEvent", descriptorFilePath) as $"event") + .select(from_protobuf($"value", "AppEvent", descriptorFilePath) as $"event") + .where("event.name == \"alice\"") + .select(to_protobuf($"user", "AppEvent", descriptorFilePath) as $"event") val query = output -.writeStream -.format("kafka") -.option("kafka.bootstrap.servers", "host1:port1,host2:port2") -.option("topic", "topic2") -.start() + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start() // Alternatively, you can decode and encode the SQL columns into protobuf format using protobuf // class name. The specified Protobuf class must match the data, otherwise the behavior is undefined: @@ -147,8 +172,8 @@ val query = output // 'com.google.protobuf.*' classes should be shaded. An example of shading can be found at // https://github.com/rangadi/shaded-protobuf-classes. var output = df -.select(from_protobuf($"value", "org.example.protos..AppEvent") as $"event") -.where("event.name == \"alice\"") + .select(from_protobuf($"value", "org.example.protos..AppEvent") as $"event") + .where("event.name == \"alice\"") output.printSchema() // root @@ -160,43 +185,56 @@ output.printSchema() output = output.select(to_protobuf($"event", "org.sparkproject.spark_protobuf.protobuf.AppEvent") as $"event") val query = output -.writeStream -.format("kafka") -.option("kafka.bootstrap.servers", "host1:port1,host2:port2") -.option("topic", "topic2") -.start() + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start() + +{% endhighlight %} + +
    ``` +
    +
    + +
    + +
    +This div is only used to make markdown editor/viewer happy and does not display on web -### Java ```java +
    + +{% highlight java %} import static org.apache.spark.sql.functions.col; import static org.apache.spark.sql.protobuf.functions.*; -// `from_protobuf` and `to_protobuf` provides two schema choices. Via Protobuf descriptor file, +// `from_protobuf` and `to_protobuf` provides two schema choices. Via the protobuf descriptor file, // or via shaded Java class. // give input .proto protobuf schema -// syntax = "proto3" +// syntax = "proto3" // message AppEvent { -// string name = 1; -// int64 id = 2; -// string context = 3; +// string name = 1; +// int64 id = 2; +// string context = 3; // } Dataset df = spark -.readStream() -.format("kafka") -.option("kafka.bootstrap.servers", "host1:port1,host2:port2") -.option("subscribe", "topic1") -.load(); + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load(); // 1. Decode the Protobuf data of schema `AppEvent` into a struct; // 2. Filter by column `name`; // 3. Encode the column `event` in Protobuf format. // The Protobuf protoc command can be used to generate a protobuf descriptor file for give .proto file. Dataset output = df -.select(from_protobuf(col("value"), "AppEvent", descriptorFilePath).as("event")) -.where("event.name == \"alice\"") -.select(to_protobuf(col("event"), "AppEvent", descriptorFilePath).as("event")); + .select(from_protobuf(col("value"), "AppEvent", descriptorFilePath).as("event")) + .where("event.name == \"alice\"") + .select(to_protobuf(col("event"), "AppEvent", descriptorFilePath).as("event")); // Alternatively, you can decode and encode the SQL columns into protobuf format using protobuf // class name. The specified Protobuf class must match the data, otherwise the behavior is undefined: @@ -204,10 +242,10 @@ Dataset output = df // 'com.google.protobuf.*' classes should be shaded. An example of shading can be found at // https://github.com/rangadi/shaded-protobuf-classes. Dataset output = df -.select( - from_protobuf(col("value"), - "org.sparkproject.spark_protobuf.protobuf.AppEvent").as("event")) -.where("event.name == \"alice\"") + .select( + from_protobuf(col("value"), + "org.sparkproject.spark_protobuf.protobuf.AppEvent").as("event")) + .where("event.name == \"alice\"") output.printSchema() // root @@ -221,19 +259,28 @@ output = output.select( "org.sparkproject.spark_protobuf.protobuf.AppEvent").as("event")); StreamingQuery query = output -.writeStream() -.format("kafka") -.option("kafka.bootstrap.servers", "host1:port1,host2:port2") -.option("topic", "topic2") -.start(); + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start(); + +{% endhighlight %} + +
    ``` +
    +
    + +
## Supported types for Protobuf -> Spark SQL conversion + Currently Spark supports reading [protobuf scalar types](https://developers.google.com/protocol-buffers/docs/proto3#scalar), [enum types](https://developers.google.com/protocol-buffers/docs/proto3#enum), [nested type](https://developers.google.com/protocol-buffers/docs/proto3#nested), and [maps type](https://developers.google.com/protocol-buffers/docs/proto3#maps) under messages of Protobuf. In addition to these types, `spark-protobuf` also introduces support for Protobuf `OneOf` fields, which allow you to handle messages that can have multiple possible sets of fields, but only one set can be present at a time. This is useful for situations where the data you are working with is not always in the same format, and you need to be able to handle messages with different sets of fields without encountering errors. -
Property Name | Default | Meaning | Since Version
spark.sql.parquet.binaryAsString ... 3.3.0
- spark.sql.parquet.timestampNTZ.enabled
+ spark.sql.parquet.inferTimestampNTZ.enabled
  true
- Enables TIMESTAMP_NTZ support for Parquet reads and writes.
- When enabled, TIMESTAMP_NTZ values are written as Parquet timestamp
- columns with annotation isAdjustedToUTC = false and are inferred in a similar way.
- When disabled, such values are read as TIMESTAMP_LTZ and have to be
- converted to TIMESTAMP_LTZ for writes.
+ When enabled, Parquet timestamp columns with annotation isAdjustedToUTC = false
+ are inferred as TIMESTAMP_NTZ type during schema inference. Otherwise, all the Parquet
+ timestamp columns are inferred as TIMESTAMP_LTZ types. Note that Spark writes the
+ output schema into Parquet's footer metadata on file writing and leverages it on file
+ reading. Thus this configuration only affects the schema inference on Parquet files
+ which are not written by Spark.
  3.4.0
- +
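As a hedged illustration of the renamed flag above: a minimal sketch, assuming an existing `SparkSession` named `spark` and a placeholder path to Parquet files that were not written by Spark.

```python
# Sketch: control whether isAdjustedToUTC=false Parquet timestamp columns are inferred
# as TIMESTAMP_NTZ (the 3.4+ default) or as TIMESTAMP_LTZ during schema inference.
spark.conf.set("spark.sql.parquet.inferTimestampNTZ.enabled", "false")  # restore LTZ inference

# `/data/external_parquet` is a placeholder for files produced by another writer;
# files written by Spark carry their schema in the footer and are unaffected.
df = spark.read.parquet("/data/external_parquet")
df.printSchema()
```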
    Protobuf typeSpark SQL type
    + @@ -282,16 +329,12 @@ In addition to the these types, `spark-protobuf` also introduces support for Pro - - - -
    Protobuf typeSpark SQL type
    boolean BooleanTypeOneOf Struct
    AnyStructType
    It also supports reading the following Protobuf types [Timestamp](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#timestamp) and [Duration](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#duration) - - +
    Protobuf logical typeProtobuf schemaSpark SQL type
    + @@ -305,10 +348,11 @@ It also supports reading the following Protobuf types [Timestamp](https://develo
    Protobuf logical typeProtobuf schemaSpark SQL type
    duration MessageType{seconds: Long, nanos: Int}
    ## Supported types for Spark SQL -> Protobuf conversion + Spark supports the writing of all Spark SQL types into Protobuf. For most types, the mapping from Spark types to Protobuf types is straightforward (e.g. IntegerType gets converted to int); - - +
    Spark SQL typeProtobuf type
    + @@ -356,15 +400,23 @@ Spark supports the writing of all Spark SQL types into Protobuf. For most types,
    Spark SQL typeProtobuf type
    BooleanType boolean
    ## Handling circular references protobuf fields + One common issue that can arise when working with Protobuf data is the presence of circular references. In Protobuf, a circular reference occurs when a field refers back to itself or to another field that refers back to the original field. This can cause issues when parsing the data, as it can result in infinite loops or other unexpected behavior. -To address this issue, the latest version of spark-protobuf introduces a new feature: the ability to check for circular references through field types. This allows users use the `recursive.fields.max.depth` option to specify the maximum number of levels of recursion to allow when parsing the schema. By default, `spark-protobuf` will not permit recursive fields by setting `recursive.fields.max.depth` to -1. However, you can set this option to 0 to 10 if needed. +To address this issue, the latest version of spark-protobuf introduces a new feature: the ability to check for circular references through field types. This allows users use the `recursive.fields.max.depth` option to specify the maximum number of levels of recursion to allow when parsing the schema. By default, `spark-protobuf` will not permit recursive fields by setting `recursive.fields.max.depth` to -1. However, you can set this option to 0 to 10 if needed. Setting `recursive.fields.max.depth` to 0 drops all recursive fields, setting it to 1 allows it to be recursed once, and setting it to 2 allows it to be recursed twice. A `recursive.fields.max.depth` value greater than 10 is not allowed, as it can lead to performance issues and even stack overflows. SQL Schema for the below protobuf message will vary based on the value of `recursive.fields.max.depth`. -```proto -syntax = "proto3" +
    +
    +This div is only used to make markdown editor/viewer happy and does not display on web + +```protobuf +
    + +{% highlight protobuf %} +syntax = "proto3" message Person { string name = 1; Person bff = 2 @@ -376,4 +428,9 @@ message Person { 0: struct 1: struct> 2: struct>> ... -``` \ No newline at end of file + +{% endhighlight %} +
    +``` +
    +
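A minimal PySpark sketch of the recursion option described above; the descriptor path and the input DataFrame are illustrative assumptions, not part of the patch.

```python
from pyspark.sql.protobuf.functions import from_protobuf

# `events` is assumed to be a DataFrame with a binary `value` column holding
# serialized `Person` messages; `/tmp/person.desc` is a descriptor file compiled
# from the Person .proto above (both are placeholders).
parsed = events.select(
    from_protobuf(
        "value", "Person", "/tmp/person.desc",
        options={"recursive.fields.max.depth": "2"}  # keep two levels of the recursive `bff` field
    ).alias("person")
)
parsed.printSchema()
```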
    \ No newline at end of file diff --git a/docs/sql-data-sources-text.md b/docs/sql-data-sources-text.md index bb485d29c396a..aed8a2e9942fb 100644 --- a/docs/sql-data-sources-text.md +++ b/docs/sql-data-sources-text.md @@ -47,7 +47,7 @@ Data source options of text can be set via: * `DataStreamWriter` * `OPTIONS` clause at [CREATE TABLE USING DATA_SOURCE](sql-ref-syntax-ddl-create-table-datasource.html) - +
    diff --git a/docs/sql-distributed-sql-engine-spark-sql-cli.md b/docs/sql-distributed-sql-engine-spark-sql-cli.md index a67e009b9ae10..6d506cbb09c21 100644 --- a/docs/sql-distributed-sql-engine-spark-sql-cli.md +++ b/docs/sql-distributed-sql-engine-spark-sql-cli.md @@ -62,7 +62,7 @@ For example: `/path/to/spark-sql-cli.sql` equals to `file:///path/to/spark-sql-c ## Supported comment types -
    Property NameDefaultMeaningScope
    wholetext
    +
    @@ -115,7 +115,7 @@ Use `;` (semicolon) to terminate commands. Notice: ``` However, if ';' is the end of the line, it terminates the SQL statement. The example above will be terminated into `/* This is a comment contains ` and `*/ SELECT 1`, Spark will submit these two commands separated and throw parser error (`unclosed bracketed comment` and `Syntax error at or near '*/'`). -
    CommentExample
    simple comment
    +
    diff --git a/docs/sql-error-conditions-invalid-default-value-error-class.md b/docs/sql-error-conditions-invalid-default-value-error-class.md index 466b5a9274cad..92c70ce69fc5f 100644 --- a/docs/sql-error-conditions-invalid-default-value-error-class.md +++ b/docs/sql-error-conditions-invalid-default-value-error-class.md @@ -29,6 +29,10 @@ This error class has the following derived error classes: which requires `` type, but the statement provided a value of incompatible `` type. +## NOT_CONSTANT + +which is not a constant expression whose equivalent value is known at query planning time. + ## SUBQUERY_EXPRESSION which contains subquery expressions. diff --git a/docs/sql-error-conditions-sqlstates.md b/docs/sql-error-conditions-sqlstates.md index 5529c961b3bfb..49cfb56b36626 100644 --- a/docs/sql-error-conditions-sqlstates.md +++ b/docs/sql-error-conditions-sqlstates.md @@ -33,7 +33,7 @@ Spark SQL uses the following `SQLSTATE` classes: ## Class `0A`: feature not supported -
    CommandDescription
    quit or exit
    +
    @@ -48,7 +48,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    0A000
    ## Class `21`: cardinality violation - +
    @@ -63,7 +63,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    21000
    ## Class `22`: data exception - +
    @@ -168,7 +168,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    22003
    ## Class `23`: integrity constraint violation - +
    @@ -183,7 +183,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    23505
    ## Class `2B`: dependent privilege descriptors still exist - +
    @@ -198,7 +198,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    2BP01
    ## Class `38`: external routine exception - +
    @@ -213,7 +213,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    38000
    ## Class `39`: external routine invocation exception - +
    @@ -228,7 +228,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    39000
    ## Class `42`: syntax error or access rule violation - +
    @@ -648,7 +648,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    42000
    ## Class `46`: java ddl 1 - +
    @@ -672,7 +672,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    46110
    ## Class `53`: insufficient resources - +
    @@ -687,7 +687,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    53200
    ## Class `54`: program limit exceeded - +
    @@ -702,7 +702,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    54000
    ## Class `HY`: CLI-specific condition - +
    @@ -717,7 +717,7 @@ Spark SQL uses the following `SQLSTATE` classes:
    SQLSTATEDescription and issuing error classes
    HY008
    ## Class `XX`: internal error - +
    diff --git a/docs/sql-error-conditions-unsupported-generator-error-class.md b/docs/sql-error-conditions-unsupported-generator-error-class.md index 7960c14767d17..38b3bbfaa3c3c 100644 --- a/docs/sql-error-conditions-unsupported-generator-error-class.md +++ b/docs/sql-error-conditions-unsupported-generator-error-class.md @@ -27,7 +27,7 @@ This error class has the following derived error classes: ## MULTI_GENERATOR -only one generator allowed per `` clause but found ``: ``. +only one generator allowed per SELECT clause but found ``: ``. ## NESTED_IN_EXPRESSIONS diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index e7df1aa9a4f9c..0cf05748f58f0 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -55,6 +55,15 @@ See '``/sql-migration-guide.html#query-engine'. Column or field `` is ambiguous and has `` matches. +### AMBIGUOUS_COLUMN_REFERENCE + +[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Column `` is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same. +This column points to one of the DataFrame but Spark is unable to figure out which one. +Please alias the DataFrames with different names via `DataFrame.alias` before joining them, +and specify the column using qualified name, e.g. `df.alias("a").join(df.alias("b"), col("a.id") > col("b.id"))`. + ### AMBIGUOUS_LATERAL_COLUMN_ALIAS [SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -87,6 +96,12 @@ Invalid as-of join. For more details see [AS_OF_JOIN](sql-error-conditions-as-of-join-error-class.html) +### AVRO_INCOMPATIBLE_READ_TYPE + +SQLSTATE: none assigned + +Cannot convert Avro `` to SQL `` because the original encoded data type is ``, however you're trying to read the field as ``, which would lead to an incorrect answer. To allow reading this field, enable the SQL configuration: "spark.sql.legacy.avro.allowIncompatibleSchema". + ### BATCH_METADATA_NOT_FOUND [SQLSTATE: 42K03](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 5fc323ec1b0ea..964f7de637e8b 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -22,6 +22,14 @@ license: | * Table of contents {:toc} +## Upgrading from Spark SQL 3.5.1 to 3.5.2 + +- Since 3.5.2, MySQL JDBC datasource will read TINYINT UNSIGNED as ShortType, while in 3.5.1, it was wrongly read as ByteType. + +## Upgrading from Spark SQL 3.5.0 to 3.5.1 + +- Since Spark 3.5.1, MySQL JDBC datasource will read TINYINT(n > 1) and TINYINT UNSIGNED as ByteType, while in Spark 3.5.0 and below, they were read as IntegerType. To restore the previous behavior, you can cast the column to the old type. + ## Upgrading from Spark SQL 3.4 to 3.5 - Since Spark 3.5, the JDBC options related to DS V2 pushdown are `true` by default. These options include: `pushDownAggregate`, `pushDownLimit`, `pushDownOffset` and `pushDownTableSample`. To restore the legacy behavior, please set them to `false`. e.g. set `spark.sql.catalog.your_catalog_name.pushDownAggregate` to `false`. @@ -30,6 +38,7 @@ license: | - Since Spark 3.5, the `plan` field is moved from `AnalysisException` to `EnhancedAnalysisException`. - Since Spark 3.5, `spark.sql.optimizer.canChangeCachedPlanOutputPartitioning` is enabled by default. 
To restore the previous behavior, set `spark.sql.optimizer.canChangeCachedPlanOutputPartitioning` to `false`. - Since Spark 3.5, the `array_insert` function is 1-based for negative indexes. It inserts new element at the end of input arrays for the index -1. To restore the previous behavior, set `spark.sql.legacy.negativeIndexInArrayInsert` to `true`. +- Since Spark 3.5, the Avro will throw `AnalysisException` when reading Interval types as Date or Timestamp types, or reading Decimal types with lower precision. To restore the legacy behavior, set `spark.sql.legacy.avro.allowIncompatibleSchema` to `true` ## Upgrading from Spark SQL 3.3 to 3.4 @@ -48,6 +57,8 @@ license: | - Since Spark 3.4, vectorized readers are enabled by default for the nested data types (array, map and struct). To restore the legacy behavior, set `spark.sql.orc.enableNestedColumnVectorizedReader` and `spark.sql.parquet.enableNestedColumnVectorizedReader` to `false`. - Since Spark 3.4, `BinaryType` is not supported in CSV datasource. In Spark 3.3 or earlier, users can write binary columns in CSV datasource, but the output content in CSV files is `Object.toString()` which is meaningless; meanwhile, if users read CSV tables with binary columns, Spark will throw an `Unsupported type: binary` exception. - Since Spark 3.4, bloom filter joins are enabled by default. To restore the legacy behavior, set `spark.sql.optimizer.runtime.bloomFilter.enabled` to `false`. + - Since Spark 3.4, when schema inference on external Parquet files, INT64 timestamps with annotation `isAdjustedToUTC=false` will be inferred as TimestampNTZ type instead of Timestamp type. To restore the legacy behavior, set `spark.sql.parquet.inferTimestampNTZ.enabled` to `false`. + - Since Spark 3.4, the behavior for `CREATE TABLE AS SELECT ...` is changed from OVERWRITE to APPEND when `spark.sql.legacy.allowNonEmptyLocationInCTAS` is set to `true`. Users are recommended to avoid CTAS with a non-empty table location. ## Upgrading from Spark SQL 3.2 to 3.3 @@ -97,6 +108,8 @@ license: | - Since Spark 3.3, the `unbase64` function throws error for a malformed `str` input. Use `try_to_binary(, 'base64')` to tolerate malformed input and return NULL instead. In Spark 3.2 and earlier, the `unbase64` function returns a best-efforts result for a malformed `str` input. + - Since Spark 3.3, when reading Parquet files that were not produced by Spark, Parquet timestamp columns with annotation `isAdjustedToUTC = false` are inferred as TIMESTAMP_NTZ type during schema inference. In Spark 3.2 and earlier, these columns are inferred as TIMESTAMP type. To restore the behavior before Spark 3.3, you can set `spark.sql.parquet.inferTimestampNTZ.enabled` to `false`. + - Since Spark 3.3.1 and 3.2.3, for `SELECT ... GROUP BY a GROUPING SETS (b)`-style SQL statements, `grouping__id` returns different values from Apache Spark 3.2.0, 3.2.1, 3.2.2, and 3.3.0. It computes based on user-given group-by expressions plus grouping set columns. To restore the behavior before 3.3.1 and 3.2.3, you can set `spark.sql.legacy.groupingIdWithAppendedUserGroupBy`. For details, see [SPARK-40218](https://issues.apache.org/jira/browse/SPARK-40218) and [SPARK-40562](https://issues.apache.org/jira/browse/SPARK-40562). ## Upgrading from Spark SQL 3.1 to 3.2 @@ -250,6 +263,8 @@ license: | - In Spark 3.0, the column metadata will always be propagated in the API `Column.name` and `Column.as`. 
In Spark version 2.4 and earlier, the metadata of `NamedExpression` is set as the `explicitMetadata` for the new column at the time the API is called, it won't change even if the underlying `NamedExpression` changes metadata. To restore the behavior before Spark 3.0, you can use the API `as(alias: String, metadata: Metadata)` with explicit metadata. + - When turning a Dataset to another Dataset, Spark will up cast the fields in the original Dataset to the type of corresponding fields in the target DataSet. In version 2.4 and earlier, this up cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throw NPE during execution. In Spark 3.0, the up cast is stricter and turning String into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` will fail during analysis. To restore the behavior before Spark 3.0, set `spark.sql.legacy.doLooseUpcast` to `true`. + ### DDL Statements - In Spark 3.0, when inserting a value into a table column with a different data type, the type coercion is performed as per ANSI SQL standard. Certain unreasonable type conversions such as converting `string` to `int` and `double` to `boolean` are disallowed. A runtime exception is thrown if the value is out-of-range for the data type of the column. In Spark version 2.4 and below, type conversions during table insertion are allowed as long as they are valid `Cast`. When inserting an out-of-range value to an integral field, the low-order bits of the value is inserted(the same as Java/Scala numeric type casting). For example, if 257 is inserted to a field of byte type, the result is 1. The behavior is controlled by the option `spark.sql.storeAssignmentPolicy`, with a default value as "ANSI". Setting the option as "Legacy" restores the previous behavior. @@ -463,12 +478,10 @@ license: | need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise, the extremely short interval that results will likely cause applications to fail. - - When turning a Dataset to another Dataset, Spark will up cast the fields in the original Dataset to the type of corresponding fields in the target DataSet. In version 2.4 and earlier, this up cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throw NPE during execution. In Spark 3.0, the up cast is stricter and turning String into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` will fail during analysis. To restore the behavior before 2.4.1, set `spark.sql.legacy.looseUpcast` to `true`. - ## Upgrading from Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. -
    SQLSTATEDescription and issuing error classes
    XX000
    +
    @@ -582,7 +595,7 @@ license: | - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below: - +
    diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index 1467409bb500d..2dec65cc553ed 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -34,7 +34,7 @@ memory usage and GC pressure. You can call `spark.catalog.uncacheTable("tableNam Configuration of in-memory caching can be done using the `setConf` method on `SparkSession` or by running `SET key=value` commands using SQL. - +
    @@ -62,7 +62,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `Sp The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically. -
    Property NameDefaultMeaningSince Version
    spark.sql.inMemoryColumnarStorage.compressed
    +
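For the in-memory caching options above, both forms mentioned in the prose work; a brief sketch assuming an existing `spark` session and an arbitrary temp view name.

```python
# Two equivalent ways to set the caching option named in the table above.
spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "true")
spark.sql("SET spark.sql.inMemoryColumnarStorage.compressed=true")

spark.range(1000).createOrReplaceTempView("t")
spark.catalog.cacheTable("t")  # cached columnar data honors the setting above
```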
    @@ -253,7 +253,7 @@ Adaptive Query Execution (AQE) is an optimization technique in Spark SQL that ma ### Coalescing Post Shuffle Partitions This feature coalesces the post shuffle partitions based on the map output statistics when both `spark.sql.adaptive.enabled` and `spark.sql.adaptive.coalescePartitions.enabled` configurations are true. This feature simplifies the tuning of shuffle partition number when running queries. You do not need to set a proper shuffle partition number to fit your dataset. Spark can pick the proper shuffle partition number at runtime once you set a large enough initial number of shuffle partitions via `spark.sql.adaptive.coalescePartitions.initialPartitionNum` configuration. -
    Property NameDefaultMeaningSince Version
    spark.sql.files.maxPartitionBytes
    +
    @@ -298,7 +298,7 @@ This feature coalesces the post shuffle partitions based on the map output stati
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.coalescePartitions.enabled
    ### Spliting skewed shuffle partitions - +
    @@ -320,7 +320,7 @@ This feature coalesces the post shuffle partitions based on the map output stati ### Converting sort-merge join to broadcast join AQE converts sort-merge join to broadcast hash join when the runtime statistics of any join side is smaller than the adaptive broadcast hash join threshold. This is not as efficient as planning a broadcast hash join in the first place, but it's better than keep doing the sort-merge join, as we can save the sorting of both the join sides, and read shuffle files locally to save network traffic(if `spark.sql.adaptive.localShuffleReader.enabled` is true) -
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.optimizeSkewsInRebalancePartitions.enabled
    +
    @@ -342,7 +342,7 @@ AQE converts sort-merge join to broadcast hash join when the runtime statistics ### Converting sort-merge join to shuffled hash join AQE converts sort-merge join to shuffled hash join when all post shuffle partitions are smaller than a threshold, the max threshold can see the config `spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold`. -
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.autoBroadcastJoinThreshold
    +
    @@ -356,7 +356,7 @@ AQE converts sort-merge join to shuffled hash join when all post shuffle partiti ### Optimizing Skew Join Data skew can severely downgrade the performance of join queries. This feature dynamically handles skew in sort-merge join by splitting (and replicating if needed) skewed tasks into roughly evenly sized tasks. It takes effect when both `spark.sql.adaptive.enabled` and `spark.sql.adaptive.skewJoin.enabled` configurations are enabled. -
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold
    +
    @@ -393,7 +393,7 @@ Data skew can severely downgrade the performance of join queries. This feature d
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.skewJoin.enabled
    ### Misc - +
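The AQE hunks above only swap table markup, but the surrounding prose names the relevant knobs (partition coalescing and skew-join handling); a minimal sketch, assuming an existing `spark` session.

```python
# Enable AQE and the two features described above; all keys appear in the
# surrounding prose and config tables.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
```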
    diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md index 5e28a18acefa4..e5d5388f262e4 100644 --- a/docs/sql-ref-datetime-pattern.md +++ b/docs/sql-ref-datetime-pattern.md @@ -41,7 +41,7 @@ Spark uses pattern letters in the following table for date and timestamp parsing |**a**|am-pm-of-day|am-pm|PM| |**h**|clock-hour-of-am-pm (1-12)|number(2)|12| |**K**|hour-of-am-pm (0-11)|number(2)|0| -|**k**|clock-hour-of-day (1-24)|number(2)|0| +|**k**|clock-hour-of-day (1-24)|number(2)|1| |**H**|hour-of-day (0-23)|number(2)|0| |**m**|minute-of-hour|number(2)|30| |**s**|second-of-minute|number(2)|55| diff --git a/docs/sql-ref-syntax-ddl-create-table-datasource.md b/docs/sql-ref-syntax-ddl-create-table-datasource.md index 7920a8a558e3d..f645732a15df9 100644 --- a/docs/sql-ref-syntax-ddl-create-table-datasource.md +++ b/docs/sql-ref-syntax-ddl-create-table-datasource.md @@ -104,7 +104,9 @@ In general CREATE TABLE is creating a "pointer", and you need to make sure it po existing. An exception is file source such as parquet, json. If you don't specify the LOCATION, Spark will create a default table location for you. -For CREATE TABLE AS SELECT, Spark will overwrite the underlying data source with the data of the +For CREATE TABLE AS SELECT with LOCATION, Spark throws analysis exceptions if the given location +exists as a non-empty directory. If `spark.sql.legacy.allowNonEmptyLocationInCTAS` is set to true, +Spark overwrites the underlying data source with the data of the input query, to make sure the table gets created contains exactly the same data as the input query. ### Examples diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index 73b21a1f7c27b..5b30786bdd7f9 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -60,7 +60,7 @@ required by Keystone. The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. -
    Property NameDefaultMeaningSince Version
    spark.sql.adaptive.optimizer.excludedRules
    +
    diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 591a4415bb1a5..11a52232510fd 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -243,7 +243,7 @@ interval in the [Spark Streaming Programming Guide](streaming-programming-guide. The following table summarizes the characteristics of both types of receivers -
    Property NameMeaningRequired
    fs.swift.service.PROVIDER.auth.url
    +
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index f8f98ca54425d..4b93fb7c89ad1 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -433,7 +433,7 @@ Streaming core artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, some of the common ones are as follows. -
    Receiver Type
    +
    @@ -820,7 +820,7 @@ Similar to that of RDDs, transformations allow the data from the input DStream t DStreams support many of the transformations available on normal Spark RDD's. Some of the common ones are as follows. -
    SourceArtifact
    Kafka spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
    Kinesis
    spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Amazon Software License]
    +
    @@ -1109,7 +1109,7 @@ JavaPairDStream windowedWordCounts = pairs.reduceByKeyAndWindow Some of the common window operations are as follows. All of these operations take the said two parameters - windowLength and slideInterval. -
    TransformationMeaning
    map(func)
    +
    @@ -1280,7 +1280,7 @@ Since the output operations actually allow the transformed data to be consumed b they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). Currently, the following output operations are defined: -
    TransformationMeaning
    window(windowLength, slideInterval)
    +
    @@ -2485,7 +2485,7 @@ enabled](#deploying-applications) and reliable receivers, there is zero data los The following table summarizes the semantics under failures: -
    Output OperationMeaning
    print()
    +
    diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 66e6efb1c8a9f..c5ffdf025b173 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -297,7 +297,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); Each row in the source has the following schema: -
    Deployment Scenario
    +
    @@ -336,7 +336,7 @@ Each row in the source has the following schema: The following options must be set for the Kafka source for both batch and streaming queries. -
    ColumnType
    key
    +
    @@ -368,7 +368,7 @@ for both batch and streaming queries. The following configurations are optional: -
    Optionvaluemeaning
    assign
    +
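A small PySpark sketch of the required and optional Kafka source options discussed above; broker addresses and topic names are placeholders.

```python
# Required: kafka.bootstrap.servers plus exactly one of assign / subscribe / subscribePattern.
df = (spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
      .option("subscribe", "topic1")
      .option("startingOffsets", "latest")   # optional
      .load())
df.printSchema()  # key, value, topic, partition, offset, timestamp, timestampType
```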
    @@ -607,7 +607,7 @@ The caching key is built up from the following information: The following properties are available to configure the consumer pool: -
    Optionvaluedefaultquery typemeaning
    startingTimestamp
    +
    @@ -657,7 +657,7 @@ Note that it doesn't leverage Apache Commons Pool due to the difference of chara The following properties are available to configure the fetched data pool: -
    Property NameDefaultMeaningSince Version
    spark.kafka.consumer.cache.capacity
    +
    @@ -685,7 +685,7 @@ solution to remove duplicates when reading the written data could be to introduc that can be used to perform de-duplication when reading. The Dataframe being written to Kafka should have the following columns in schema: -
    Property NameDefaultMeaningSince Version
    spark.kafka.consumer.fetchedData.cache.timeout
    +
    @@ -725,7 +725,7 @@ will be used. The following options must be set for the Kafka sink for both batch and streaming queries. -
    ColumnType
    key (optional)
    +
    @@ -736,7 +736,7 @@ for both batch and streaming queries. The following configurations are optional: -
    Optionvaluemeaning
    kafka.bootstrap.servers
    +
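And a matching sink sketch for the required sink options named above (the topic can alternatively be supplied as a `topic` column instead of an option); hosts and paths are placeholders.

```python
query = (df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
           .writeStream
           .format("kafka")
           .option("kafka.bootstrap.servers", "host1:9092,host2:9092")  # required
           .option("topic", "topic2")                # required unless a topic column is present
           .option("checkpointLocation", "/tmp/kafka-sink-checkpoint")
           .start())
```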
    @@ -912,7 +912,7 @@ It will use different Kafka producer when delegation token is renewed; Kafka pro The following properties are available to configure the producer pool: -
    Optionvaluedefaultquery typemeaning
    topic
    +
    @@ -1039,7 +1039,7 @@ When none of the above applies then unsecure connection assumed. Delegation tokens can be obtained from multiple clusters and ${cluster} is an arbitrary unique identifier which helps to group different configurations. -
    Property NameDefaultMeaningSince Version
    spark.kafka.producer.cache.timeout
    +
    diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 76a22621a0e32..845f0617898b4 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -545,7 +545,7 @@ checkpointed offsets after a failure. See the earlier section on [fault-tolerance semantics](#fault-tolerance-semantics). Here are the details of all the sources in Spark. -
    Property NameDefaultMeaningSince Version
    spark.kafka.clusters.${cluster}.auth.bootstrap.servers
    +
    @@ -1819,7 +1819,7 @@ regarding watermark delays and whether data will be dropped or not. ##### Support matrix for joins in streaming queries -
    Source
    +
    @@ -2307,7 +2307,7 @@ to `org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider`. Here are the configs regarding to RocksDB instance of the state store provider: -
    Left Input
    +
    @@ -2474,7 +2474,7 @@ More information to be added in future releases. Different types of streaming queries support different output modes. Here is the compatibility matrix. -
    Config Name
    +
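A hedged sketch of switching to the RocksDB provider class mentioned above; the config key is the standard state-store provider setting, and an existing `spark` session is assumed.

```python
# Use the RocksDB-backed state store for stateful streaming operators.
spark.conf.set(
    "spark.sql.streaming.stateStore.providerClass",
    "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
```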
    @@ -2613,7 +2613,7 @@ meant for debugging purposes only. See the earlier section on [fault-tolerance semantics](#fault-tolerance-semantics). Here are the details of all the sinks in Spark. -
    Query Type
    +
    @@ -3201,7 +3201,7 @@ The trigger settings of a streaming query define the timing of streaming data pr the query is going to be executed as micro-batch query with a fixed batch interval or as a continuous processing query. Here are the different kinds of triggers that are supported. -
    Sink
    +
    @@ -3831,10 +3831,10 @@ class Listener(StreamingQueryListener): print("Query started: " + queryStarted.id) def onQueryProgress(self, event): - println("Query terminated: " + queryTerminated.id) + print("Query made progress: " + queryProgress.progress) def onQueryTerminated(self, event): - println("Query made progress: " + queryProgress.progress) + print("Query terminated: " + queryTerminated.id) spark.streams.addListener(Listener()) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index becdfb4b18f5d..4821f883eef9d 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -159,7 +159,7 @@ export HADOOP_CONF_DIR=XXX The master URL passed to Spark can be in one of the following formats: -
    Trigger Type
    +
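A brief sketch of two of the trigger kinds referred to above, using the PySpark `DataStreamWriter.trigger` API; the console sink and the input DataFrame are illustrative only.

```python
# Micro-batches every 10 seconds.
q1 = df.writeStream.format("console").trigger(processingTime="10 seconds").start()

# Process all currently available data in one or more micro-batches, then stop.
q2 = df.writeStream.format("console").trigger(availableNow=True).start()
```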
    diff --git a/docs/web-ui.md b/docs/web-ui.md index 079bc6137f020..cdf62e0d8ec0b 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -380,7 +380,7 @@ operator shows the number of bytes written by a shuffle. Here is the list of SQL metrics: -
    Master URLMeaning
    local Run Spark locally with one worker thread (i.e. no parallelism at all).
    local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
    +
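For the master URLs listed above, a minimal, self-contained PySpark sketch; the application name is arbitrary.

```python
from pyspark.sql import SparkSession

# `local[4]` runs Spark locally with 4 worker threads, per the table above;
# the same string can be passed to spark-submit via --master.
spark = (SparkSession.builder
         .master("local[4]")
         .appName("master-url-example")
         .getOrCreate())
print(spark.sparkContext.master)  # -> local[4]
```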
    diff --git a/examples/pom.xml b/examples/pom.xml index c95269dbc4bb0..26d91eff504f2 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 30fe05957d569..c4f250b40f33d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index f67475ac11bc0..a47d25015dfa9 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index d3c52a713911a..5c1844be5782d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 172fb8c560876..ff729cd1cb6dc 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -21,6 +21,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; /** * Helper methods for command builders. @@ -30,6 +33,11 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; + // This should be consistent with org.apache.spark.internal.config.SECRET_REDACTION_PATTERN + // We maintain this copy to avoid depending on `core` module. + static final String SECRET_REDACTION_PATTERN = "(?i)secret|password|token|access[.]?key"; + static final Pattern redactPattern = Pattern.compile(SECRET_REDACTION_PATTERN); + static final Pattern keyValuePattern = Pattern.compile("-D(.+?)=(.+)"); /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { @@ -328,4 +336,23 @@ static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfN return libdir.getAbsolutePath(); } + /** + * Redact a command-line argument's value part which matches `-Dkey=value` pattern. + * Note that this should be consistent with `org.apache.spark.util.Utils.redactCommandLineArgs`. + */ + static List redactCommandLineArgs(List args) { + return args.stream().map(CommandBuilderUtils::redact).collect(Collectors.toList()); + } + + /** + * Redact a command-line argument's value part which matches `-Dkey=value` pattern. 
+ */ + static String redact(String arg) { + Matcher m = keyValuePattern.matcher(arg); + if (m.find() && redactPattern.matcher(m.group(1)).find()) { + return String.format("-D%s=%s", m.group(1), "*********(redacted)"); + } else { + return arg; + } + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java b/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java index 013dde2766f49..f6a9607e7c5d3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java +++ b/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java @@ -36,6 +36,7 @@ public class JavaModuleOptions { "--add-opens=java.base/java.util=ALL-UNNAMED", "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED", "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", "--add-opens=java.base/sun.security.action=ALL-UNNAMED", diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 6501fc1764c25..321fca0912704 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -114,7 +114,7 @@ private static List buildCommand( boolean printLaunchCommand) throws IOException, IllegalArgumentException { List cmd = builder.buildCommand(env); if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("Spark Command: " + join(" ", redactCommandLineArgs(cmd))); System.err.println("========================================"); } return cmd; diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index 46cdffc190d52..1b2c683880c25 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -68,6 +68,16 @@ public void testInvalidOptionStrings() { testInvalidOpt("'abcde"); } + @Test + public void testRedactCommandLineArgs() { + assertEquals(redact("secret"), "secret"); + assertEquals(redact("-Dk=v"), "-Dk=v"); + assertEquals(redact("-Dk=secret"), "-Dk=secret"); + assertEquals(redact("-DsecretKey=my-secret"), "-DsecretKey=*********(redacted)"); + assertEquals(redactCommandLineArgs(Arrays.asList("-DsecretKey=my-secret")), + Arrays.asList("-DsecretKey=*********(redacted)")); + } + @Test public void testWindowsBatchQuoting() { assertEquals("abc", quoteForBatchScript("abc")); diff --git a/licenses/LICENSE-copybutton.txt b/licenses/LICENSE-copybutton.txt deleted file mode 100644 index 45be6b83a53be..0000000000000 --- a/licenses/LICENSE-copybutton.txt +++ /dev/null @@ -1,49 +0,0 @@ -PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 --------------------------------------------- - -1. This LICENSE AGREEMENT is between the Python Software Foundation -("PSF"), and the Individual or Organization ("Licensee") accessing and -otherwise using this software ("Python") in source or binary form and -its associated documentation. - -2. 
Subject to the terms and conditions of this License Agreement, PSF hereby -grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, -analyze, test, perform and/or display publicly, prepare derivative works, -distribute, and otherwise use Python alone or in any derivative version, -provided, however, that PSF's License Agreement and PSF's notice of copyright, -i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, -2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019 Python Software Foundation; -All Rights Reserved" are retained in Python alone or in any derivative version -prepared by Licensee. - -3. In the event Licensee prepares a derivative work that is based on -or incorporates Python or any part thereof, and wants to make -the derivative work available to others as provided herein, then -Licensee hereby agrees to include in any such work a brief summary of -the changes made to Python. - -4. PSF is making Python available to Licensee on an "AS IS" -basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND -DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT -INFRINGE ANY THIRD PARTY RIGHTS. - -5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON -FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS -A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, -OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. - -6. This License Agreement will automatically terminate upon a material -breach of its terms and conditions. - -7. Nothing in this License Agreement shall be deemed to create any -relationship of agency, partnership, or joint venture between PSF and -Licensee. This License Agreement does not grant permission to use PSF -trademarks or trade name in a trademark sense to endorse or promote -products or services of Licensee, or any third party. - -8. By copying, installing or otherwise using Python, Licensee -agrees to be bound by the terms and conditions of this License -Agreement. 
- diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 5ec981a7816be..bb821190273e1 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml @@ -37,7 +37,19 @@ org.scalanlp breeze_${scala.binary.version} + + + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + + + org.apache.commons commons-math3 diff --git a/mllib/pom.xml b/mllib/pom.xml index fe7c3da9c4eb2..202b80d38e24f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml @@ -96,10 +96,20 @@ org.scala-lang.modules scala-parallel-collections_${scala.binary.version} + + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + --> org.scalanlp breeze_${scala.binary.version} + + + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + + org.apache.commons diff --git a/pom.xml b/pom.xml index 93d696d494e84..3d9b003bd19c8 100644 --- a/pom.xml +++ b/pom.xml @@ -26,13 +26,13 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT pom Spark Project Parent POM https://spark.apache.org/ - Apache 2.0 License + Apache-2.0 http://www.apache.org/licenses/LICENSE-2.0.html repo @@ -115,7 +115,7 @@ 1.8 ${java.version} ${java.version} - 3.8.8 + 3.9.6 3.1.0 spark 9.5 @@ -141,9 +141,9 @@ 10.14.2.0 1.13.1 - 1.9.1 + 1.9.4 shaded-protobuf - 9.4.52.v20230823 + 9.4.54.v20240208 4.0.3 0.10.0 4.8.0 @@ -186,11 +187,11 @@ 1.9.13 2.15.2 2.15.2 - 1.1.10.3 + 1.1.10.5 3.0.3 - 1.16.0 + 1.16.1 1.23.0 - 2.13.0 + 2.16.1 2.6 @@ -218,7 +219,7 @@ 3.1.0 1.1.0 1.5.0 - 1.70 + 1.77 1.9.0 4.1.96.Final org.fusesource.leveldbjni 6.7.2 + 1.17.6 ${java.home} @@ -308,6 +310,7 @@ --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED + --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED @@ -624,6 +627,16 @@ org.apache.commons commons-compress ${commons-compress.version} + + + commons-io + commons-io + + + org.apache.commons + commons-lang3 + + org.apache.commons @@ -1091,6 +1104,11 @@ scala-xml_${scala.binary.version} 2.1.0 + + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + ${scala-collection-compat.version} + org.scala-lang scala-compiler @@ -1145,6 +1163,12 @@ selenium-4-9_${scala.binary.version} 3.2.16.0 test + + + org.seleniumhq.selenium + htmlunit-driver + + org.mockito @@ -1195,22 +1219,31 @@ test - com.spotify - docker-client - 8.14.1 + com.github.docker-java + docker-java + 3.3.4 test - shaded - - guava - com.google.guava - commons-logging commons-logging + + com.github.docker-java + docker-java-transport-netty + + + com.github.docker-java + docker-java-transport-jersey + + + com.github.docker-java + docker-java-transport-zerodep + 3.3.4 + test + com.mysql mysql-connector-j @@ -1226,7 +1259,7 @@ org.postgresql postgresql - 42.6.0 + 42.7.2 test @@ -1427,13 +1460,13 @@ org.bouncycastle - bcprov-jdk15on + bcprov-jdk18on ${bouncycastle.version} test org.bouncycastle - bcpkix-jdk15on + bcpkix-jdk18on ${bouncycastle.version} test @@ -1451,6 +1484,16 @@ org.apache.avro avro ${avro.version} + + + commons-io + commons-io + + + org.apache.commons + commons-lang3 + + org.apache.avro @@ -1490,6 +1533,14 @@ com.github.luben zstd-jni + + commons-io + commons-io + + + 
org.apache.commons + commons-lang3 + src false @@ -3057,8 +3117,6 @@ ${spark.test.docker.removePulledImage} __not_used__ - - test:/// ${test.exclude.tags},${test.default.exclude.tags} ${test.include.tags} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9805ad7f09d6e..ae026165addc1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,11 +41,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.SQLUserDefinedType"), // [SPARK-43165][SQL] Move canWrite to DataTypeUtils ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.canWrite"), - // [SPARK-43195][CORE] Remove unnecessary serializable wrapper in HadoopFSUtils - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.HadoopFSUtils$SerializableBlockLocation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.HadoopFSUtils$SerializableBlockLocation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.HadoopFSUtils$SerializableFileStatus"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.HadoopFSUtils$SerializableFileStatus$"), // [SPARK-43792][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listCatalogs ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listCatalogs"), // [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases @@ -72,7 +67,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.java.function.MapGroupsWithStateFunction"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SaveMode"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"), + // [SPARK-46480][CORE][SQL] Fix NPE when table cache task attempt + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.isFailed") ) // Default exclude rules diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 563d53577548e..f8659a4f4a257 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -91,7 +91,7 @@ object BuildCommons { // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. val protoVersion = "3.23.4" // GRPC version used for Spark Connect. 
- val gprcVersion = "1.56.0" + val grpcVersion = "1.56.0" } object SparkBuild extends PomBuild { @@ -160,16 +160,21 @@ object SparkBuild extends PomBuild { val replacements = Map( """customId="println" level="error"""" -> """customId="println" level="warn"""" ) - var contents = Source.fromFile(in).getLines.mkString("\n") - for ((k, v) <- replacements) { - require(contents.contains(k), s"Could not rewrite '$k' in original scalastyle config.") - contents = contents.replace(k, v) - } - new PrintWriter(out) { - write(contents) - close() + val source = Source.fromFile(in) + try { + var contents = source.getLines.mkString("\n") + for ((k, v) <- replacements) { + require(contents.contains(k), s"Could not rewrite '$k' in original scalastyle config.") + contents = contents.replace(k, v) + } + new PrintWriter(out) { + write(contents) + close() + } + out + } finally { + source.close() } - out } // Return a cached scalastyle task for a given configuration (usually Compile or Test) @@ -464,8 +469,7 @@ object SparkBuild extends PomBuild { /* Protobuf settings */ enable(SparkProtobuf.settings)(protobuf) - // SPARK-14738 - Remove docker tests from main Spark build - // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + enable(DockerIntegrationTests.settings)(dockerIntegrationTests) if (!profiles.contains("volcano")) { enable(Volcano.settings)(kubernetes) @@ -567,7 +571,6 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.catalyst.expressions.MathExpressionsSuite", "org.apache.spark.sql.hive.HiveExternalCatalogSuite", "org.apache.spark.sql.hive.StatisticsSuite", - "org.apache.spark.sql.hive.client.VersionsSuite", "org.apache.spark.sql.hive.client.HiveClientVersions", "org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite", "org.apache.spark.ml.classification.LogisticRegressionSuite", @@ -580,6 +583,7 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite", "org.apache.spark.sql.kafka010.KafkaDelegationTokenSuite", + "org.apache.spark.sql.streaming.RocksDBStateStoreStreamingAggregationSuite", "org.apache.spark.shuffle.KubernetesLocalDiskShuffleDataIOSuite", "org.apache.spark.sql.hive.HiveScalaReflectionSuite" ) @@ -693,7 +697,7 @@ object SparkConnectCommon { SbtPomKeys.effectivePom.value.getProperties.get( "guava.failureaccess.version").asInstanceOf[String] Seq( - "io.grpc" % "protoc-gen-grpc-java" % BuildCommons.gprcVersion asProtocPlugin(), + "io.grpc" % "protoc-gen-grpc-java" % BuildCommons.grpcVersion asProtocPlugin(), "com.google.guava" % "guava" % guavaVersion, "com.google.guava" % "failureaccess" % guavaFailureaccessVersion, "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" @@ -1401,7 +1405,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/io"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) - .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive"))) @@ -1529,7 +1533,11 @@ object CopyDependencies { if (destJar.isFile()) { destJar.delete() } - if 
(jar.getName.contains("spark-connect") && + + if (jar.getName.contains("spark-connect-common") && + !SbtPomKeys.profiles.value.contains("noshade-connect")) { + // Don't copy the spark-connect-common JAR, as it is shaded into the spark-connect JAR. + } else if (jar.getName.contains("spark-connect") && + !SbtPomKeys.profiles.value.contains("noshade-connect")) { Files.copy(fid.toPath, destJar.toPath) } else if (jar.getName.contains("connect-client-jvm") && @@ -1591,7 +1599,6 @@ object TestSettings { (Test / javaOptions) += "-Dspark.ui.enabled=false", (Test / javaOptions) += "-Dspark.ui.showConsoleProgress=false", (Test / javaOptions) += "-Dspark.unsafe.exceptionOnMemoryLeak=true", - (Test / javaOptions) += "-Dspark.hadoop.hadoop.security.key.provider.path=test:///", (Test / javaOptions) += "-Dhive.conf.validation=false", (Test / javaOptions) += "-Dsun.io.serialization.extendedDebugInfo=false", (Test / javaOptions) += "-Dderby.system.durability=test", @@ -1619,6 +1626,7 @@ "--add-opens=java.base/java.util=ALL-UNNAMED", "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED", "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", "--add-opens=java.base/sun.security.action=ALL-UNNAMED", diff --git a/python/docs/source/_static/copybutton.js b/python/docs/source/_static/copybutton.js deleted file mode 100644 index 896faad3f9df1..0000000000000 --- a/python/docs/source/_static/copybutton.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2014 PSF. Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -// File originates from the cpython source found in Doc/tools/sphinxext/static/copybutton.js - -$(document).ready(function() { - /* Add a [>>>] button on the top-right corner of code samples to hide - * the >>> and ... prompts and the output and thus make the code - * copyable.
*/ - var div = $('.highlight-python .highlight,' + - '.highlight-default .highlight,' + - '.highlight-python3 .highlight') - var pre = div.find('pre'); - - // get the styles from the current theme - pre.parent().parent().css('position', 'relative'); - var hide_text = 'Hide the prompts and output'; - var show_text = 'Show the prompts and output'; - var border_width = pre.css('border-top-width'); - var border_style = pre.css('border-top-style'); - var border_color = pre.css('border-top-color'); - var button_styles = { - 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', - 'border-color': border_color, 'border-style': border_style, - 'border-width': border_width, 'color': border_color, 'text-size': '75%', - 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', - 'border-radius': '0 3px 0 0', - 'user-select': 'none' - } - - // create and add the button to all the code blocks that contain >>> - div.each(function(index) { - var jthis = $(this); - if (jthis.find('.gp').length > 0) { - var button = $('>>>'); - button.css(button_styles) - button.attr('title', hide_text); - button.data('hidden', 'false'); - jthis.prepend(button); - } - // tracebacks (.gt) contain bare text elements that need to be - // wrapped in a span to work with .nextUntil() (see later) - jthis.find('pre:has(.gt)').contents().filter(function() { - return ((this.nodeType == 3) && (this.data.trim().length > 0)); - }).wrap(''); - }); - - // define the behavior of the button when it's clicked - $('.copybutton').click(function(e){ - e.preventDefault(); - var button = $(this); - if (button.data('hidden') === 'false') { - // hide the code output - button.parent().find('.go, .gp, .gt').hide(); - button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); - button.css('text-decoration', 'line-through'); - button.attr('title', show_text); - button.data('hidden', 'true'); - } else { - // show the code output - button.parent().find('.go, .gp, .gt').show(); - button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); - button.css('text-decoration', 'none'); - button.attr('title', hide_text); - button.data('hidden', 'false'); - } - }); -}); - diff --git a/python/docs/source/_static/css/pyspark.css b/python/docs/source/_static/css/pyspark.css index 89b7c65f27a51..ccfe60f2bca64 100644 --- a/python/docs/source/_static/css/pyspark.css +++ b/python/docs/source/_static/css/pyspark.css @@ -95,3 +95,16 @@ u.bd-sidebar .nav>li>ul>.active:hover>a,.bd-sidebar .nav>li>ul>.active>a { .spec_table tr, td, th { border-top: none!important; } + +/* Styling to the version dropdown */ +#version-button { + padding-left: 0.2rem; + padding-right: 3.2rem; +} + +#version_switcher { + height: auto; + max-height: 300px; + width: 165px; + overflow-y: auto; +} diff --git a/python/docs/source/_templates/version-switcher.html b/python/docs/source/_templates/version-switcher.html new file mode 100644 index 0000000000000..16c443229f4be --- /dev/null +++ b/python/docs/source/_templates/version-switcher.html @@ -0,0 +1,77 @@ + + + + + diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py index 38c331048e7b6..1b5cf34744651 100644 --- a/python/docs/source/conf.py +++ b/python/docs/source/conf.py @@ -63,6 +63,7 @@ 'sphinx.ext.viewcode', 'sphinx.ext.mathjax', 'sphinx.ext.autosummary', + 'sphinx_copybutton', 'nbsphinx', # Converts Jupyter Notebook to reStructuredText files for Sphinx. # For ipython directive in reStructuredText files. It is generated by the notebook. 
'IPython.sphinxext.ipython_console_highlighting', @@ -70,6 +71,9 @@ 'sphinx_plotly_directive', # For visualize plot result ] +# sphinx copy button +copybutton_exclude = '.linenos, .gp, .go' + # plotly plot directive plotly_include_source = True plotly_html_show_formats = False @@ -94,9 +98,9 @@ .. |examples| replace:: Examples .. _examples: https://github.com/apache/spark/tree/{0}/examples/src/main/python .. |downloading| replace:: Downloading -.. _downloading: https://spark.apache.org/docs/{1}/building-spark.html +.. _downloading: https://spark.apache.org/docs/{1}/#downloading .. |building_spark| replace:: Building Spark -.. _building_spark: https://spark.apache.org/docs/{1}/#downloading +.. _building_spark: https://spark.apache.org/docs/{1}/building-spark.html """.format( os.environ.get("GIT_HASH", "master"), os.environ.get("RELEASE_VERSION", "latest"), @@ -177,10 +181,21 @@ # a list of builtin themes. html_theme = 'pydata_sphinx_theme' +html_context = { + # When releasing a new Spark version, please update the file + # "site/static/versions.json" under the code repository "spark-website" + # (item should be added in order), and also set the local environment + # variable "RELEASE_VERSION". + "switcher_json_url": "https://spark.apache.org/static/versions.json", + "switcher_template_url": "https://spark.apache.org/docs/{version}/api/python/index.html", +} + # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +html_theme_options = { + "navbar_end": ["version-switcher"] +} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] @@ -409,9 +424,6 @@ # If false, no index is generated. #epub_use_index = True -def setup(app): - # The app.add_javascript() is deprecated. 
- getattr(app, "add_js_file", getattr(app, "add_javascript", None))('copybutton.js') # Skip sample endpoint link (not expected to resolve) linkcheck_ignore = [r'https://kinesis.us-east-1.amazonaws.com'] diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 6822285e96172..e97632a8b384b 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -157,7 +157,7 @@ Package Supported version Note ========================== ========================= ====================================================================================== `py4j` >=0.10.9.7 Required `pandas` >=1.0.5 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL -`pyarrow` >=4.0.0 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL +`pyarrow` >=4.0.0,<13.0.0 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL `numpy` >=1.15 Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL `grpcio` >=1.48,<1.57 Required for Spark Connect `grpcio-status` >=1.48,<1.57 Required for Spark Connect diff --git a/python/docs/source/getting_started/quickstart_connect.ipynb b/python/docs/source/getting_started/quickstart_connect.ipynb index 15a2ab749d2a6..0397a0ebf5071 100644 --- a/python/docs/source/getting_started/quickstart_connect.ipynb +++ b/python/docs/source/getting_started/quickstart_connect.ipynb @@ -28,7 +28,9 @@ "metadata": {}, "outputs": [], "source": [ - "!$HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:$SPARK_VERSION" + "%%bash\n", + "source ~/.profile # Make sure environment variables are loaded.\n", + "$HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:$SPARK_VERSION" ] }, { diff --git a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py index ab7e3ab3c9adc..d8aa54dcf9bee 100644 --- a/python/pyspark/ml/connect/functions.py +++ b/python/pyspark/ml/connect/functions.py @@ -39,6 +39,7 @@ def array_to_vector(col: Column) -> Column: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -54,7 +55,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("ml.connect.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index f3e621c19f0f0..f0d60a117e12f 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -15,12 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - import unittest +import os + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: @@ -31,7 +32,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py index ce7cf03049d3c..58076dfe0bbe6 100644 --- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py +++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin -have_torcheval = True +have_torcheval = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torcheval # noqa: F401 except ImportError: @@ -29,7 +30,9 @@ @unittest.skipIf(not have_torcheval, "torcheval is required") class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py index d7698c3772201..49021f6e82c5a 100644 --- a/python/pyspark/ml/tests/connect/test_connect_feature.py +++ b/python/pyspark/ml/tests/connect/test_connect_feature.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py index 7da3d3f1addd8..fc3344ecebfe2 100644 --- a/python/pyspark/ml/tests/connect/test_connect_function.py +++ b/python/pyspark/ml/tests/connect/test_connect_function.py @@ -33,6 +33,7 @@ from pyspark.ml.connect import functions as CF +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py index e676c8bfee955..eb2bedddbe283 100644 --- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py +++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py @@ -15,16 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py index 0b0537dfee3cd..28cfa4b4dc1b3 100644 --- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py +++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py index 18673d4b26be9..901367e44d20b 100644 --- a/python/pyspark/ml/tests/connect/test_connect_tuning.py +++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py @@ -15,16 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
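# A minimal sketch of the pattern these connect test diffs apply: the remote
# endpoint is read from SPARK_CONNECT_TESTING_REMOTE (falling back to a local
# master), and cases that need JVM access are skipped when
# SPARK_SKIP_CONNECT_COMPAT_TESTS is set. The class name below is hypothetical.
import os
import unittest

from pyspark.sql import SparkSession


@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
class ExampleConnectTest(unittest.TestCase):
    def setUp(self) -> None:
        # Use the configured Spark Connect endpoint, or fall back to local[2].
        self.spark = SparkSession.builder.remote(
            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
        ).getOrCreate()

    def tearDown(self) -> None:
        self.spark.stop()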
# - +import os import unittest from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py index 18556633d89f8..60f683bf726ca 100644 --- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py @@ -15,10 +15,11 @@ # limitations under the License. # +import os import unittest from pyspark.sql import SparkSession -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py index b855332f96c42..238775ded2a21 100644 --- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py @@ -19,7 +19,7 @@ import shutil import unittest -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: @@ -81,7 +81,7 @@ def _get_inputs_for_test_local_training_succeeds(self): ] -@unittest.skipIf(not have_torch, "torch is required") +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class TorchDistributorLocalUnitTestsIIOnConnect( TorchDistributorLocalUnitTestsMixin, unittest.TestCase ): diff --git a/python/pyspark/ml/torch/tests/test_data_loader.py b/python/pyspark/ml/torch/tests/test_data_loader.py index 67ab6e378ceaa..f7814f8195416 100644 --- a/python/pyspark/ml/torch/tests/test_data_loader.py +++ b/python/pyspark/ml/torch/tests/test_data_loader.py @@ -15,10 +15,11 @@ # limitations under the License. # +import os import numpy as np import unittest -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index dd93e31d0235e..74e0b328e4dfb 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -815,7 +815,7 @@ def symmetric_difference( # type: ignore[override] sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other)) if sort: - sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_columns) + sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names) internal = InternalFrame( spark_frame=sdf_symdiff, diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 95ca92e78787d..b54ae88616fa5 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -5910,7 +5910,7 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: # then return monotonically_increasing_id. 
This will let max by # to return last index value, which is the behaviour of pandas else spark_column.isNotNull(), - monotonically_increasing_id_column, + F.col(monotonically_increasing_id_column), ), ) for index in where diff --git a/python/pyspark/pandas/supported_api_gen.py b/python/pyspark/pandas/supported_api_gen.py index 06591c5b26ad6..8c3cdec3671c1 100644 --- a/python/pyspark/pandas/supported_api_gen.py +++ b/python/pyspark/pandas/supported_api_gen.py @@ -138,23 +138,11 @@ def _create_supported_by_module( # module not implemented return {} - pd_funcs = dict( - [ - m - for m in getmembers(pd_module, isfunction) - if not m[0].startswith("_") and m[0] in pd_module.__dict__ - ] - ) + pd_funcs = dict([m for m in getmembers(pd_module, isfunction) if not m[0].startswith("_")]) if not pd_funcs: return {} - ps_funcs = dict( - [ - m - for m in getmembers(ps_module, isfunction) - if not m[0].startswith("_") and m[0] in ps_module.__dict__ - ] - ) + ps_funcs = dict([m for m in getmembers(ps_module, isfunction) if not m[0].startswith("_")]) return _organize_by_implementation_status( module_name, pd_funcs, ps_funcs, pd_module_group, ps_module_group diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index 6e41395186d3d..012eabf958eb8 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -792,9 +792,21 @@ def _new_type_holders( isinstance(param, slice) and param.step is None and param.stop is not None for param in params ) - is_unnamed_params = all( - not isinstance(param, slice) and not isinstance(param, Iterable) for param in params - ) + if sys.version_info < (3, 11): + is_unnamed_params = all( + not isinstance(param, slice) and not isinstance(param, Iterable) for param in params + ) + else: + # PEP 646 changes `GenericAlias` instances into iterable ones at Python 3.11 + is_unnamed_params = all( + not isinstance(param, slice) + and ( + not isinstance(param, Iterable) + or isinstance(param, typing.GenericAlias) + or isinstance(param, typing._GenericAlias) + ) + for param in params + ) if is_named_params: # DataFrame["id": int, "A": int] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ea9a31022298..aa63c6509dce8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1120,7 +1120,7 @@ def takeSample( Parameters ---------- - withReplacement : list + withReplacement : bool whether sampling is done with replacement num : int size of the returned sample diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py index bf019ef8fe7d7..821660fdbd302 100644 --- a/python/pyspark/sql/connect/avro/functions.py +++ b/python/pyspark/sql/connect/avro/functions.py @@ -85,7 +85,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.avro.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 2a54a0d727af9..069a8d013ff32 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -326,6 +326,7 @@ def registerFunction( def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -333,7 +334,9 @@ def _test() -> None: globs = pyspark.sql.connect.catalog.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.catalog 
tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.catalog tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7b3299d123b97..7b1aafbefebbe 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1005,6 +1005,7 @@ def close(self) -> None: """ Close the channel. """ + ExecutePlanResponseReattachableIterator.shutdown() self._channel.close() self._closed = True diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 7e1e722d5fd8a..6addb5bd2c652 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -18,10 +18,11 @@ check_dependencies(__name__) +from threading import RLock import warnings import uuid from collections.abc import Generator -from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast +from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar from multiprocessing.pool import ThreadPool import os @@ -53,7 +54,30 @@ class ExecutePlanResponseReattachableIterator(Generator): ReleaseExecute RPCs that instruct the server to release responses that it already processed. """ - _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) + # Lock to manage the pool + _lock: ClassVar[RLock] = RLock() + _release_thread_pool: Optional[ThreadPool] = ThreadPool(os.cpu_count() if os.cpu_count() else 8) + + @classmethod + def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: + """ + When the channel is closed, this method will be called before, to make sure all + outstanding calls are closed. + """ + with cls._lock: + if cls._release_thread_pool is not None: + cls._release_thread_pool.close() + cls._release_thread_pool.join() + cls._release_thread_pool = None + + @classmethod + def _initialize_pool_if_necessary(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: + """ + If the processing pool for the release calls is None, initialize the pool exactly once. 
+ """ + with cls._lock: + if cls._release_thread_pool is None: + cls._release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) def __init__( self, @@ -62,6 +86,7 @@ def __init__( retry_policy: Dict[str, Any], metadata: Iterable[Tuple[str, str]], ): + ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary() self._request = request self._retry_policy = retry_policy if request.operation_id: @@ -111,7 +136,6 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse: self._last_returned_response_id = ret.response_id if ret.HasField("result_complete"): - self._result_complete = True self._release_all() else: self._release_until(self._last_returned_response_id) @@ -190,7 +214,8 @@ def target() -> None: except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + if ExecutePlanResponseReattachableIterator._release_thread_pool is not None: + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) def _release_all(self) -> None: """ @@ -218,7 +243,8 @@ def target() -> None: except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + if ExecutePlanResponseReattachableIterator._release_thread_pool is not None: + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 0529293816338..464f5397b85b6 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -483,6 +483,7 @@ def __nonzero__(self) -> None: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -490,7 +491,9 @@ def _test() -> None: globs = pyspark.sql.connect.column.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.column tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.column tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py index d323de716c46a..cb296a750e62f 100644 --- a/python/pyspark/sql/connect/conf.py +++ b/python/pyspark/sql/connect/conf.py @@ -97,6 +97,7 @@ def _checkType(self, obj: Any, identifier: str) -> None: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -104,7 +105,9 @@ def _test() -> None: globs = pyspark.sql.connect.conf.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.conf tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 7b326538a8e0a..6f23a15fb4ad1 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -687,7 +687,7 @@ def sample( if withReplacement is None: withReplacement = False - seed = int(seed) if seed is not None else None + seed = int(seed) if seed is not None else random.randint(0, 
sys.maxsize) return DataFrame.withPlan( plan.Sample( @@ -2150,7 +2150,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.dataframe tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index e2583f84c417c..ecb800bbee93d 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -3906,6 +3906,7 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: def _test() -> None: import sys + import os import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.functions @@ -3914,7 +3915,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index a393d2cb37e89..2d5a66fd6ef92 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -388,6 +388,7 @@ def _extract_cols(gd: "GroupedData") -> List[Column]: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -396,7 +397,9 @@ def _test() -> None: globs = pyspark.sql.connect.group.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.group tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.group tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 84fd013d0144a..43af8bb427a5a 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -398,9 +398,6 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() clr = plan.cached_local_relation - if session._user_id: - clr.userId = session._user_id - clr.sessionId = session._session_id clr.hash = self._hash return plan @@ -1199,6 +1196,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() + plan.common.plan_id = self._child._plan_id plan.collect_metrics.input.CopyFrom(self._child.plan(session)) plan.collect_metrics.name = self._name plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs]) @@ -1657,16 +1655,16 @@ def command(self, session: "SparkConnectClient") -> proto.Command: plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE elif wm == "overwrite": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE - elif wm == "overwrite_partition": + if self.overwrite_condition is not None: + plan.write_operation_v2.overwrite_condition.CopyFrom( + self.col_to_expr(self.overwrite_condition, session) + ) + elif wm == "overwrite_partitions": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS elif wm == "append": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_APPEND elif wm == "replace": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE - if self.overwrite_condition is not None: - plan.write_operation_v2.overwrite_condition.CopyFrom( - 
self.col_to_expr(self.overwrite_condition, session) - ) elif wm == "create_or_replace": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE else: @@ -2125,7 +2123,9 @@ def __init__( self._input_grouping_cols = input_grouping_cols self._other_grouping_cols = other_grouping_cols self._other = cast(LogicalPlan, other) - self._func = function._build_common_inline_user_defined_function(*cols) + # The function takes entire DataFrame as inputs, no need to do + # column binding (no input columns). + self._func = function._build_common_inline_user_defined_function() def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 3a0a7ff71fd3b..3f7e57949373b 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xe1\x18\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b 
\x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! \x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"[\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id"\xe7\x01\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x30\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryR\x04\x61rgs\x12<\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07posArgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 
\x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xc6\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 
\x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x81\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"_\n\x13\x43\x61\x63hedLocalRelation\x12\x16\n\x06userId\x18\x01 \x01(\tR\x06userId\x12\x1c\n\tsessionId\x18\x02 \x01(\tR\tsessionId\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hash"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 
\x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xef\x01\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x65\n\x12rename_columns_map\x18\x02 
\x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xb5\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x42\r\n\x0b_is_barrier"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schemaB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xe1\x18\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b 
\x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"[\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id"\xe7\x01\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x30\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryR\x04\x61rgs\x12<\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07posArgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 
\x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xc6\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 
\x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x81\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xef\x01\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x65\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xb5\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x42\r\n\x0b_is_barrier"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schemaB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -111,85 +111,85 @@ _LOCALRELATION._serialized_start = 7090 _LOCALRELATION._serialized_end = 7179 _CACHEDLOCALRELATION._serialized_start = 7181 - _CACHEDLOCALRELATION._serialized_end = 7276 - _CACHEDREMOTERELATION._serialized_start = 7278 - _CACHEDREMOTERELATION._serialized_end = 7333 - _SAMPLE._serialized_start = 7336 - _SAMPLE._serialized_end = 7609 - _RANGE._serialized_start = 7612 - _RANGE._serialized_end = 7757 - _SUBQUERYALIAS._serialized_start = 7759 - _SUBQUERYALIAS._serialized_end = 7873 - _REPARTITION._serialized_start = 7876 - _REPARTITION._serialized_end = 8018 - _SHOWSTRING._serialized_start = 8021 - _SHOWSTRING._serialized_end = 8163 - _HTMLSTRING._serialized_start = 8165 - _HTMLSTRING._serialized_end = 8279 - _STATSUMMARY._serialized_start = 8281 - _STATSUMMARY._serialized_end = 8373 - _STATDESCRIBE._serialized_start = 8375 - _STATDESCRIBE._serialized_end = 8456 - _STATCROSSTAB._serialized_start = 8458 - _STATCROSSTAB._serialized_end = 8559 - _STATCOV._serialized_start = 8561 - _STATCOV._serialized_end = 8657 - _STATCORR._serialized_start = 8660 - _STATCORR._serialized_end = 8797 - _STATAPPROXQUANTILE._serialized_start = 8800 - _STATAPPROXQUANTILE._serialized_end = 8964 - _STATFREQITEMS._serialized_start = 8966 - _STATFREQITEMS._serialized_end = 9091 - _STATSAMPLEBY._serialized_start = 9094 - _STATSAMPLEBY._serialized_end = 9403 - _STATSAMPLEBY_FRACTION._serialized_start = 9295 - _STATSAMPLEBY_FRACTION._serialized_end = 9394 - 
_NAFILL._serialized_start = 9406 - _NAFILL._serialized_end = 9540 - _NADROP._serialized_start = 9543 - _NADROP._serialized_end = 9677 - _NAREPLACE._serialized_start = 9680 - _NAREPLACE._serialized_end = 9976 - _NAREPLACE_REPLACEMENT._serialized_start = 9835 - _NAREPLACE_REPLACEMENT._serialized_end = 9976 - _TODF._serialized_start = 9978 - _TODF._serialized_end = 10066 - _WITHCOLUMNSRENAMED._serialized_start = 10069 - _WITHCOLUMNSRENAMED._serialized_end = 10308 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10241 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10308 - _WITHCOLUMNS._serialized_start = 10310 - _WITHCOLUMNS._serialized_end = 10429 - _WITHWATERMARK._serialized_start = 10432 - _WITHWATERMARK._serialized_end = 10566 - _HINT._serialized_start = 10569 - _HINT._serialized_end = 10701 - _UNPIVOT._serialized_start = 10704 - _UNPIVOT._serialized_end = 11031 - _UNPIVOT_VALUES._serialized_start = 10961 - _UNPIVOT_VALUES._serialized_end = 11020 - _TOSCHEMA._serialized_start = 11033 - _TOSCHEMA._serialized_end = 11139 - _REPARTITIONBYEXPRESSION._serialized_start = 11142 - _REPARTITIONBYEXPRESSION._serialized_end = 11345 - _MAPPARTITIONS._serialized_start = 11348 - _MAPPARTITIONS._serialized_end = 11529 - _GROUPMAP._serialized_start = 11532 - _GROUPMAP._serialized_end = 12167 - _COGROUPMAP._serialized_start = 12170 - _COGROUPMAP._serialized_end = 12696 - _APPLYINPANDASWITHSTATE._serialized_start = 12699 - _APPLYINPANDASWITHSTATE._serialized_end = 13056 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13059 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13303 - _PYTHONUDTF._serialized_start = 13306 - _PYTHONUDTF._serialized_end = 13483 - _COLLECTMETRICS._serialized_start = 13486 - _COLLECTMETRICS._serialized_end = 13622 - _PARSE._serialized_start = 13625 - _PARSE._serialized_end = 14013 + _CACHEDLOCALRELATION._serialized_end = 7253 + _CACHEDREMOTERELATION._serialized_start = 7255 + _CACHEDREMOTERELATION._serialized_end = 7310 + _SAMPLE._serialized_start = 7313 + _SAMPLE._serialized_end = 7586 + _RANGE._serialized_start = 7589 + _RANGE._serialized_end = 7734 + _SUBQUERYALIAS._serialized_start = 7736 + _SUBQUERYALIAS._serialized_end = 7850 + _REPARTITION._serialized_start = 7853 + _REPARTITION._serialized_end = 7995 + _SHOWSTRING._serialized_start = 7998 + _SHOWSTRING._serialized_end = 8140 + _HTMLSTRING._serialized_start = 8142 + _HTMLSTRING._serialized_end = 8256 + _STATSUMMARY._serialized_start = 8258 + _STATSUMMARY._serialized_end = 8350 + _STATDESCRIBE._serialized_start = 8352 + _STATDESCRIBE._serialized_end = 8433 + _STATCROSSTAB._serialized_start = 8435 + _STATCROSSTAB._serialized_end = 8536 + _STATCOV._serialized_start = 8538 + _STATCOV._serialized_end = 8634 + _STATCORR._serialized_start = 8637 + _STATCORR._serialized_end = 8774 + _STATAPPROXQUANTILE._serialized_start = 8777 + _STATAPPROXQUANTILE._serialized_end = 8941 + _STATFREQITEMS._serialized_start = 8943 + _STATFREQITEMS._serialized_end = 9068 + _STATSAMPLEBY._serialized_start = 9071 + _STATSAMPLEBY._serialized_end = 9380 + _STATSAMPLEBY_FRACTION._serialized_start = 9272 + _STATSAMPLEBY_FRACTION._serialized_end = 9371 + _NAFILL._serialized_start = 9383 + _NAFILL._serialized_end = 9517 + _NADROP._serialized_start = 9520 + _NADROP._serialized_end = 9654 + _NAREPLACE._serialized_start = 9657 + _NAREPLACE._serialized_end = 9953 + _NAREPLACE_REPLACEMENT._serialized_start = 9812 + _NAREPLACE_REPLACEMENT._serialized_end = 9953 + _TODF._serialized_start = 9955 + _TODF._serialized_end 
= 10043 + _WITHCOLUMNSRENAMED._serialized_start = 10046 + _WITHCOLUMNSRENAMED._serialized_end = 10285 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10218 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10285 + _WITHCOLUMNS._serialized_start = 10287 + _WITHCOLUMNS._serialized_end = 10406 + _WITHWATERMARK._serialized_start = 10409 + _WITHWATERMARK._serialized_end = 10543 + _HINT._serialized_start = 10546 + _HINT._serialized_end = 10678 + _UNPIVOT._serialized_start = 10681 + _UNPIVOT._serialized_end = 11008 + _UNPIVOT_VALUES._serialized_start = 10938 + _UNPIVOT_VALUES._serialized_end = 10997 + _TOSCHEMA._serialized_start = 11010 + _TOSCHEMA._serialized_end = 11116 + _REPARTITIONBYEXPRESSION._serialized_start = 11119 + _REPARTITIONBYEXPRESSION._serialized_end = 11322 + _MAPPARTITIONS._serialized_start = 11325 + _MAPPARTITIONS._serialized_end = 11506 + _GROUPMAP._serialized_start = 11509 + _GROUPMAP._serialized_end = 12144 + _COGROUPMAP._serialized_start = 12147 + _COGROUPMAP._serialized_end = 12673 + _APPLYINPANDASWITHSTATE._serialized_start = 12676 + _APPLYINPANDASWITHSTATE._serialized_end = 13033 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13036 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13280 + _PYTHONUDTF._serialized_start = 13283 + _PYTHONUDTF._serialized_end = 13460 + _COLLECTMETRICS._serialized_start = 13463 + _COLLECTMETRICS._serialized_end = 13599 + _PARSE._serialized_start = 13602 + _PARSE._serialized_end = 13990 _PARSE_OPTIONSENTRY._serialized_start = 3987 _PARSE_OPTIONSENTRY._serialized_end = 4045 - _PARSE_PARSEFORMAT._serialized_start = 13914 - _PARSE_PARSEFORMAT._serialized_end = 14002 + _PARSE_PARSEFORMAT._serialized_start = 13891 + _PARSE_PARSEFORMAT._serialized_end = 13979 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 9cadd4acc5224..007b92ef5f42d 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1647,28 +1647,15 @@ class CachedLocalRelation(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - USERID_FIELD_NUMBER: builtins.int - SESSIONID_FIELD_NUMBER: builtins.int HASH_FIELD_NUMBER: builtins.int - userId: builtins.str - """(Required) An identifier of the user which created the local relation""" - sessionId: builtins.str - """(Required) An identifier of the Spark SQL session in which the user created the local relation.""" hash: builtins.str """(Required) A sha-256 hash of the serialized local relation in proto, see LocalRelation.""" def __init__( self, *, - userId: builtins.str = ..., - sessionId: builtins.str = ..., hash: builtins.str = ..., ) -> None: ... - def ClearField( - self, - field_name: typing_extensions.Literal[ - "hash", b"hash", "sessionId", b"sessionId", "userId", b"userId" - ], - ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["hash", b"hash"]) -> None: ... 
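The stub change above narrows CachedLocalRelation to a single hash field, so a cached local relation is now addressed purely by the sha-256 of its serialized proto rather than by user/session identifiers. The following is an illustrative sketch of deriving such a key on the client side; the exact hashing scheme Spark Connect applies server-side is not shown in this diff, so treat the helper and its usage as assumptions.

import hashlib

import pyspark.sql.connect.proto as proto


def local_relation_hash(rel: proto.LocalRelation) -> str:
    # Hex sha-256 of the serialized LocalRelation message, matching the
    # "sha-256 hash of the serialized local relation in proto" field doc.
    return hashlib.sha256(rel.SerializeToString()).hexdigest()


# Hypothetical usage: build a LocalRelation and reference it by hash only.
rel = proto.LocalRelation(schema="id BIGINT")
cached = proto.CachedLocalRelation(hash=local_relation_hash(rel))
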
global___CachedLocalRelation = CachedLocalRelation diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py index 56119f4bc4eb9..c8e12640b3136 100644 --- a/python/pyspark/sql/connect/protobuf/functions.py +++ b/python/pyspark/sql/connect/protobuf/functions.py @@ -144,7 +144,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.protobuf.functions tests") - .remote("local[2]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index cfcbcede34873..7cfdf9910d7e0 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -830,6 +830,7 @@ def overwritePartitions(self) -> None: def _test() -> None: import sys + import os import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.readwriter @@ -838,7 +839,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.readwriter tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 1307c8bdd84e1..10d599ca397b9 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -910,6 +910,7 @@ def session_id(self) -> str: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -917,7 +918,9 @@ def _test() -> None: globs = pyspark.sql.connect.session.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.session tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.session tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) # Uses PySpark session to test builder. 
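The _test() hunks above, and the ones that follow for streaming/query.py, streaming/readwriter.py and window.py, all apply the same pattern: the doctest session reads its Connect endpoint from SPARK_CONNECT_TESTING_REMOTE and falls back to a local master, so the same doctests can run either self-contained or against an external Spark Connect server. A minimal sketch of the pattern, assuming a plain PySpark installation:

import os

from pyspark.sql import SparkSession

# Defaults to an in-process local[4] session; export
# SPARK_CONNECT_TESTING_REMOTE=sc://host:15002 to target a running server.
spark = (
    SparkSession.builder.appName("doctest session")
    .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
    .getOrCreate()
)

When the variable is unset, the behaviour is identical to the previous hard-coded remote("local[4]") call.
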
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 021d27e939de8..7d968b175f281 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -276,7 +276,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.streaming.query tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 89097fcf43a01..afee833fda4e9 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -586,6 +586,7 @@ def toTable( def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -595,7 +596,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.streaming.readwriter tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py index ad082c6e265db..922a641c2428c 100644 --- a/python/pyspark/sql/connect/window.py +++ b/python/pyspark/sql/connect/window.py @@ -235,6 +235,7 @@ def rangeBetween(start: int, end: int) -> "WindowSpec": def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -242,7 +243,9 @@ def _test() -> None: globs = pyspark.sql.connect.window.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.window tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.window tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 30ed73d3c47b0..afa979dab019e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -529,6 +529,7 @@ def writeStream(self) -> DataStreamWriter: Examples -------- + >>> import time >>> import tempfile >>> df = spark.readStream.format("rate").load() >>> type(df.writeStream) @@ -536,9 +537,10 @@ def writeStream(self) -> DataStreamWriter: >>> with tempfile.TemporaryDirectory() as d: ... # Create a table with Rate source. - ... df.writeStream.toTable( + ... query = df.writeStream.toTable( ... "my_table", checkpointLocation=d) - <...streaming.query.StreamingQuery object at 0x...> + ... time.sleep(3) + ... 
query.stop() """ return DataStreamWriter(self) @@ -942,7 +944,11 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = age | 16 name | Bob """ + print(self._show_string(n, truncate, vertical)) + def _show_string( + self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False + ) -> str: if not isinstance(n, int) or isinstance(n, bool): raise PySparkTypeError( error_class="NOT_INT", @@ -956,7 +962,7 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = ) if isinstance(truncate, bool) and truncate: - print(self._jdf.showString(n, 20, vertical)) + return self._jdf.showString(n, 20, vertical) else: try: int_truncate = int(truncate) @@ -969,7 +975,7 @@ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = }, ) - print(self._jdf.showString(n, int_truncate, vertical)) + return self._jdf.showString(n, int_truncate, vertical) def __repr__(self) -> str: if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled(): @@ -1485,7 +1491,7 @@ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: self.rdd.foreachPartition(f) # type: ignore[arg-type] def cache(self) -> "DataFrame": - """Persists the :class:`DataFrame` with the default storage level (`MEMORY_AND_DISK`). + """Persists the :class:`DataFrame` with the default storage level (`MEMORY_AND_DISK_DESER`). .. versionadded:: 1.3.0 @@ -1494,7 +1500,7 @@ def cache(self) -> "DataFrame": Notes ----- - The default storage level has changed to `MEMORY_AND_DISK` to match Scala in 2.0. + The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. Returns ------- @@ -1507,7 +1513,7 @@ def cache(self) -> "DataFrame": >>> df.cache() DataFrame[id: bigint] - >>> df.explain() + >>> df.explain() # doctest: +SKIP == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- InMemoryTableScan ... @@ -1550,7 +1556,7 @@ def persist( >>> df.persist() DataFrame[id: bigint] - >>> df.explain() + >>> df.explain() # doctest: +SKIP == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- InMemoryTableScan ... 
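The show() hunk above separates rendering from printing: show() now delegates to a new private _show_string() helper that returns the formatted table, which presumably lets other callers capture the output instead of writing it to stdout. A small sketch of what that enables once this patch is applied; _show_string is internal and is used here purely for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("show-string demo").getOrCreate()
df = spark.range(3)

rendered = df._show_string(3, 20, False)  # returns the rendered table as a str
assert isinstance(rendered, str)
df.show(3)                                # unchanged behaviour: prints the same table
spark.stop()
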
@@ -3881,8 +3887,8 @@ def union(self, other: "DataFrame") -> "DataFrame": >>> df2 = spark.createDataFrame([(3, "Charlie"), (4, "Dave")], ["id", "name"]) >>> df1 = df1.withColumn("age", lit(30)) >>> df2 = df2.withColumn("age", lit(40)) - >>> df3 = df1.union(df2) - >>> df3.show() + >>> df3 = df1.union(df2) # doctest: +SKIP + >>> df3.show() # doctest: +SKIP +-----+-------+---+ | name| id|age| +-----+-------+---+ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 06cb3063d1b16..7e1a8faf00178 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7973,7 +7973,7 @@ def to_unix_timestamp( >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_unix_timestamp(df.e).alias('r')).collect() + >>> df.select(to_unix_timestamp(df.e).alias('r')).collect() # doctest: +SKIP [Row(r=None)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -8084,7 +8084,7 @@ def current_database() -> Column: Examples -------- - >>> spark.range(1).select(current_database()).show() + >>> spark.range(1).select(current_database()).show() # doctest: +SKIP +------------------+ |current_database()| +------------------+ @@ -8103,7 +8103,7 @@ def current_schema() -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select(sf.current_schema()).show() + >>> spark.range(1).select(sf.current_schema()).show() # doctest: +SKIP +------------------+ |current_database()| +------------------+ diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 8664c4df73ed8..3643cafbb3baf 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -613,6 +613,7 @@ def _create_from_pandas_with_arrow( # Slice the DataFrame to be batched step = self._jconf.arrowMaxRecordsPerBatch() + step = step if step > 0 else len(pdf) pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step)) # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index bc26fdede2888..710fc8a9a370a 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -60,11 +60,10 @@ def mapInPandas( schema : :class:`pyspark.sql.types.DataType` or str the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - barrier : bool, optional, default True + barrier : bool, optional, default False Use barrier mode execution. - .. versionchanged: 3.5.0 - Added ``barrier`` argument. + .. versionadded: 3.5.0 Examples -------- @@ -139,11 +138,10 @@ def mapInArrow( schema : :class:`pyspark.sql.types.DataType` or str the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - barrier : bool, optional, default True + barrier : bool, optional, default False Use barrier mode execution. - .. versionchanged: 3.5.0 - Added ``barrier`` argument. + .. 
versionadded: 3.5.0 Examples -------- diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index b02a003e632cb..9b7c40ddb24e6 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -898,7 +898,7 @@ def convert_timestamp(value: Any) -> Any: return None else: if isinstance(value, datetime.datetime) and value.tzinfo is not None: - ts = pd.Timstamp(value) + ts = pd.Timestamp(value) else: ts = pd.Timestamp(value).tz_localize(timezone) return ts.to_pydatetime() diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 16f40396490c7..3a0f30872dc8c 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -107,7 +107,9 @@ def onQueryProgress(self, event: "QueryProgressEvent") -> None: """ pass - @abstractmethod + # NOTE: Do not mark this as abstract method, since we released this abstract class without + # this method in prior version and marking this as abstract method would break existing + # implementations. def onQueryIdle(self, event: "QueryIdleEvent") -> None: """ Called when the query is idle and waiting for new data to process. diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 8bf01b3ebd983..1d375de04b49b 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -18,7 +18,7 @@ import json from typing import Tuple, Optional -from pyspark.sql.types import DateType, Row, StructType +from pyspark.sql.types import Row, StructType, TimestampType from pyspark.sql.utils import has_numpy from pyspark.errors import PySparkTypeError, PySparkValueError @@ -195,7 +195,7 @@ def setTimeoutDuration(self, durationMs: int) -> None: error_class="VALUE_NOT_POSITIVE", message_parameters={ "arg_name": "durationMs", - "arg_type": type(durationMs).__name__, + "arg_value": type(durationMs).__name__, }, ) self._timeout_timestamp = durationMs + self._batch_processing_time_ms @@ -214,14 +214,14 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: ) if isinstance(timestampMs, datetime.datetime): - timestampMs = DateType().toInternal(timestampMs) + timestampMs = TimestampType().toInternal(timestampMs) / 1000 if timestampMs <= 0: raise PySparkValueError( error_class="VALUE_NOT_POSITIVE", message_parameters={ "arg_name": "timestampMs", - "arg_type": type(timestampMs).__name__, + "arg_value": type(timestampMs).__name__, }, ) diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index d45230e926b16..cf3eea0b55607 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -146,6 +146,7 @@ def test_add_file(self): ) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): @classmethod def root(cls): @@ -389,6 +390,7 @@ def test_cache_artifact(self): self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires local-cluster") class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): @classmethod def conf(cls): @@ -403,7 +405,7 @@ def root(cls): @classmethod def master(cls): - return "local-cluster[2,2,512]" + return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[2,2,512]") if __name__ == "__main__": diff 
--git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 98f68767b8bca..cf43fb16df7a7 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -17,14 +17,20 @@ import unittest import uuid -from typing import Optional +from collections.abc import Generator +from typing import Optional, Any + +import grpc from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder import pyspark.sql.connect.proto as proto from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.sql.connect.client.core import Retrying -from pyspark.sql.connect.client.reattach import RetryException +from pyspark.sql.connect.client.reattach import ( + RetryException, + ExecutePlanResponseReattachableIterator, +) if should_test_connect: import pandas as pd @@ -120,6 +126,191 @@ def test_channel_builder_with_session(self): self.assertEqual(client._session_id, chan.session_id) +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class SparkConnectClientReattachTestCase(unittest.TestCase): + def setUp(self) -> None: + self.request = proto.ExecutePlanRequest() + self.policy = { + "max_retries": 3, + "backoff_multiplier": 4.0, + "initial_backoff": 10, + "max_backoff": 10, + "jitter": 10, + "min_jitter_threshold": 10, + } + self.response = proto.ExecutePlanResponse( + response_id="1", + ) + self.finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete(), + response_id="2", + ) + + def _stub_with(self, execute=None, attach=None): + return MockSparkConnectStub( + execute_ops=ResponseGenerator(execute) if execute is not None else None, + attach_ops=ResponseGenerator(attach) if attach is not None else None, + ) + + def assertEventually(self, callable, timeout_ms=1000): + """Helper method that will continuously evaluate the callable to not raise an + exception.""" + import time + + limit = time.monotonic_ns() + timeout_ms * 1000 * 1000 + while time.monotonic_ns() < limit: + try: + callable() + break + except Exception: + time.sleep(0.1) + callable() + + def test_basic_flow(self): + stub = self._stub_with([self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check_all(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check_all, timeout_ms=1000) + + def test_fail_during_execute(self): + def fatal(): + raise TestException("Fatal") + + stub = self._stub_with([self.response, fatal]) + with self.assertRaises(TestException): + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(0, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + def test_fail_and_retry_during_execute(self): + def non_fatal(): + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + + stub = self._stub_with( + [self.response, non_fatal], [self.response, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(1, 
stub.attach_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + def test_fail_and_retry_during_reattach(self): + count = 0 + + def non_fatal(): + nonlocal count + if count < 2: + count += 1 + raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE) + else: + return proto.ExecutePlanResponse() + + stub = self._stub_with( + [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished] + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + for b in ite: + pass + + def check(): + self.assertEqual(2, stub.attach_calls) + self.assertEqual(3, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.execute_calls) + + self.assertEventually(check, timeout_ms=1000) + + +class TestException(grpc.RpcError, grpc.Call): + """Exception mock to test retryable exceptions.""" + + def __init__(self, msg, code=grpc.StatusCode.INTERNAL): + self.msg = msg + self._code = code + + def code(self): + return self._code + + def __str__(self): + return self.msg + + def trailing_metadata(self): + return () + + +class ResponseGenerator(Generator): + """This class is used to generate values that are returned by the streaming + iterator of the GRPC stub.""" + + def __init__(self, funs): + self._funs = funs + self._iterator = iter(self._funs) + + def send(self, value: Any) -> proto.ExecutePlanResponse: + val = next(self._iterator) + if callable(val): + return val() + else: + return val + + def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: + super().throw(type, value, traceback) + + def close(self) -> None: + return super().close() + + +class MockSparkConnectStub: + """Simple mock class for the GRPC stub used by the re-attachable execution.""" + + def __init__(self, execute_ops=None, attach_ops=None): + self._execute_ops = execute_ops + self._attach_ops = attach_ops + # Call counters + self.execute_calls = 0 + self.release_calls = 0 + self.release_until_calls = 0 + self.attach_calls = 0 + + def ExecutePlan(self, *args, **kwargs): + self.execute_calls += 1 + return self._execute_ops + + def ReattachExecute(self, *args, **kwargs): + self.attach_calls += 1 + return self._attach_ops + + def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): + if req.HasField("release_all"): + self.release_calls += 1 + elif req.HasField("release_until"): + print("increment") + self.release_until_calls += 1 + + class MockService: # Simplest mock of the SparkConnectService. # If this needs more complex logic, it needs to be replaced with Python mocking. diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 5069a76cfdb73..99da04315f0ba 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - -import unittest import time +import os +import unittest import pyspark.cloudpickle from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin @@ -25,7 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase -class TestListener(StreamingQueryListener): +class TestListenerSpark(StreamingQueryListener): def onQueryStarted(self, event): e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame(data=[(e,)]) @@ -46,47 +46,55 @@ def onQueryTerminated(self, event): class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_listener_events(self): - test_listener = TestListener() + test_listener = TestListenerSpark() try: - self.spark.streams.addListener(test_listener) - - # This ensures the read socket on the server won't crash (i.e. because of timeout) - # when there hasn't been a new event for a long time - time.sleep(30) - - df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - df_observe = df.observe("my_event", count(lit(1)).alias("rc")) - df_stateful = df_observe.groupBy().count() # make query stateful - q = ( - df_stateful.writeStream.format("noop") - .queryName("test") - .outputMode("complete") - .start() - ) - - self.assertTrue(q.isActive) - time.sleep(10) - self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least one batch is ran - q.stop() - self.assertFalse(q.isActive) - - start_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_start_events").collect()[0][0] - ) - - progress_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_progress_events").collect()[0][0] - ) - - terminated_event = pyspark.cloudpickle.loads( - self.spark.read.table("listener_terminated_events").collect()[0][0] - ) - - self.check_start_event(start_event) - self.check_progress_event(progress_event) - self.check_terminated_event(terminated_event) + with self.table( + "listener_start_events", + "listener_progress_events", + "listener_terminated_events", + ): + self.spark.streams.addListener(test_listener) + + # This ensures the read socket on the server won't crash (i.e. 
because of timeout) + # when there hasn't been a new event for a long time + time.sleep(30) + + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + df_observe = df.observe("my_event", count(lit(1)).alias("rc")) + df_stateful = df_observe.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) + + self.assertTrue(q.isActive) + time.sleep(10) + self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least one batch is ran + q.stop() + self.assertFalse(q.isActive) + + start_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_start_events").collect()[0][0] + ) + + progress_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_progress_events").collect()[0][0] + ) + + terminated_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_terminated_events").collect()[0][0] + ) + + self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) finally: self.spark.streams.removeListener(test_listener) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py index 6fe2b89408014..6b23c15775fe6 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from pyspark.sql.tests.streaming.test_streaming import StreamingTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 2904eb42587e8..d1f6994edba7b 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -80,6 +80,7 @@ from pyspark.sql.connect.client.core import Retrying, SparkConnectClient +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): """Parent test fixture class for all Spark Connect related test cases.""" @@ -1241,6 +1242,13 @@ def test_sql_with_named_args(self): df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7}) self.assert_eq(df.toPandas(), df2.toPandas()) + def test_namedargs_with_global_limit(self): + sqlText = """SELECT * FROM VALUES (TIMESTAMP('2022-12-25 10:30:00'), 1) as tab(date, val) + where val = :val""" + df = self.connect.sql(sqlText, args={"val": 1}) + df2 = self.spark.sql(sqlText, args={"val": 1}) + self.assert_eq(df.toPandas(), df2.toPandas()) + def test_sql_with_pos_args(self): df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) @@ -3250,12 +3258,15 @@ def test_df_caache(self): self.assertTrue(df.is_cached) +@unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Session creation different from local mode" +) class SparkConnectSessionTests(ReusedConnectTestCase): def setUp(self) -> None: self.spark = ( PySparkSession.builder.config(conf=self.conf()) .appName(self.__class__.__name__) - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) @@ -3347,6 
+3358,7 @@ def test_can_create_multiple_sessions_to_different_remotes(self): self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: self.spark = ( diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index a5d330fe1a7e9..dc101e98e01d4 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -36,6 +36,7 @@ from pyspark.sql.connect.dataframe import DataFrame as CDF +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index c39fb6be24cdd..88ef37511a666 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -430,7 +430,7 @@ def test_sample(self): self.assertEqual(plan.root.sample.lower_bound, 0.0) self.assertEqual(plan.root.sample.upper_bound, 0.3) self.assertEqual(plan.root.sample.with_replacement, False) - self.assertEqual(plan.root.sample.HasField("seed"), False) + self.assertEqual(plan.root.sample.HasField("seed"), True) self.assertEqual(plan.root.sample.deterministic_order, False) plan = ( diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index a92ef971cd216..abcf839f0fc56 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -16,6 +16,7 @@ # import unittest +import sys from distutils.version import LooseVersion import pandas as pd @@ -136,6 +137,10 @@ def test_createDataFrame_nested_timestamp(self): def test_toPandas_nested_timestamp(self): self.check_toPandas_nested_timestamp(True) + @unittest.skipIf(sys.version_info < (3, 9), "zoneinfo is available from Python 3.9+") + def test_toPandas_timestmap_tzinfo(self): + self.check_toPandas_timestmap_tzinfo(True) + def test_createDataFrame_udt(self): self.check_createDataFrame_udt(True) diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py index c950ca2e17c3c..a508fe1059ed8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import os import unittest from pyspark.sql.tests.pandas.test_pandas_udf_scalar import ScalarPandasUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -31,6 +32,9 @@ def test_vectorized_udf_empty_partition(self): def test_vectorized_udf_struct_with_empty_partition(self): super().test_vectorized_udf_struct_with_empty_partition() + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_vectorized_udf_exception(self): self.check_vectorized_udf_exception() diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 1222b1bb5b44f..ebf7692a20cd8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import unittest from pyspark.testing.connectutils import should_test_connect if should_test_connect: @@ -57,6 +59,78 @@ def eval(self, a: int): ): TestUDTF(lit(1)).collect() + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_init_with_additional_args(self): + super(UDTFParityTests, self).test_udtf_init_with_additional_args() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_with_wrong_num_input(self): + super(UDTFParityTests, self).test_udtf_with_wrong_num_input() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_array_output_type_casting(self): + super(UDTFParityTests, self).test_array_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_map_output_type_casting(self): + super(UDTFParityTests, self).test_map_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_string_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_string_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_string_output_type_casting(self): + super(UDTFParityTests, self).test_string_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_string_output_type_casting(self): + super(UDTFParityTests, self).test_string_output_type_casting() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_struct_output_type_casting_dict(self): + super(UDTFParityTests, self).test_struct_output_type_casting_dict() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client 
<> Server" + ) + def test_udtf_init_with_additional_args(self): + super(UDTFParityTests, self).test_udtf_init_with_additional_args() + + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_with_wrong_num_input(self): + super(UDTFParityTests, self).test_udtf_with_wrong_num_input() + class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py index 917cb58057f7f..19fa9cd93f321 100644 --- a/python/pyspark/sql/tests/connect/test_utils.py +++ b/python/pyspark/sql/tests/connect/test_utils.py @@ -14,13 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import unittest from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.sql.tests.test_utils import UtilsTestsMixin class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin): - pass + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_assert_approx_equal_decimaltype_custom_rtol_pass(self): + super(ConnectUtilsTests, self).test_assert_approx_equal_decimaltype_custom_rtol_pass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index b867156e71a5d..948ef4a53f2cf 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -166,7 +166,7 @@ def check_apply_in_pandas_not_returning_pandas_dataframe(self): fn=lambda lft, rgt: lft.size + rgt.size, error_class=PythonException, error_message_regex="Return type of the user-defined function " - "should be pandas.DataFrame, but is int64.", + "should be pandas.DataFrame, but is int", ) def test_apply_in_pandas_returning_column_names(self): @@ -445,6 +445,41 @@ def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: actual = df.orderBy("id", "day").take(days) self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)]) + def test_with_local_data(self): + df1 = self.spark.createDataFrame( + [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2") + ) + df2 = self.spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3")) + + def summarize(left, right): + return pd.DataFrame( + { + "left_rows": [len(left)], + "left_columns": [len(left.columns)], + "right_rows": [len(right)], + "right_columns": [len(right.columns)], + } + ) + + df = ( + df1.groupby("id") + .cogroup(df2.groupby("id")) + .applyInPandas( + summarize, + schema="left_rows long, left_columns long, right_rows long, right_columns long", + ) + ) + + self.assertEqual( + df._show_string(), + "+---------+------------+----------+-------------+\n" + "|left_rows|left_columns|right_rows|right_columns|\n" + "+---------+------------+----------+-------------+\n" + "| 2| 3| 2| 2|\n" + "| 2| 3| 1| 2|\n" + "+---------+------------+----------+-------------+\n", + ) + @staticmethod def _test_with_key(left, right, isLeft): def right_assign_key(key, lft, rgt): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index fb2f9214c5d8f..4b2be2bcf8442 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -110,7 +110,7 @@ 
def func(iterator): df = ( self.spark.range(10, numPartitions=3) .select(col("id").cast("string").alias("str")) - .withColumn("bin", encode(col("str"), "utf8")) + .withColumn("bin", encode(col("str"), "utf-8")) ) actual = df.mapInPandas(func, "str string, bin binary").collect() expected = df.collect() @@ -151,14 +151,14 @@ def bad_iter_elem(_): with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be iterator of pandas.DataFrame, " - "but is int.", + "but is int", ): (self.spark.range(10, numPartitions=3).mapInPandas(no_iter, "a int").count()) with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be iterator of pandas.DataFrame, " - "but is iterator of int.", + "but is iterator of int", ): (self.spark.range(10, numPartitions=3).mapInPandas(bad_iter_elem, "a int").count()) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 34cd9c2358195..b54e5608f3d09 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest import datetime from typing import cast @@ -262,6 +262,9 @@ def foo(x): .collect, ) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_pandas_udf_detect_unsafe_type_conversion(self): import pandas as pd import numpy as np @@ -285,6 +288,9 @@ def udf(column): with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}): df.select(["A"]).withColumn("udf", udf("A")).collect() + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_pandas_udf_arrow_overflow(self): import pandas as pd diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 0eea86dc73756..69a5a2b90986e 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -264,36 +264,37 @@ def test_stream_await_termination(self): shutil.rmtree(tmpPath) def test_stream_exception(self): - sdf = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - sq = sdf.writeStream.format("memory").queryName("query_explain").start() - try: - sq.processAllAvailable() - self.assertEqual(sq.exception(), None) - finally: - sq.stop() - - from pyspark.sql.functions import col, udf - from pyspark.errors import StreamingQueryException - - bad_udf = udf(lambda x: 1 / 0) - sq = ( - sdf.select(bad_udf(col("value"))) - .writeStream.format("memory") - .queryName("this_query") - .start() - ) - try: - # Process some data to fail the query - sq.processAllAvailable() - self.fail("bad udf should fail the query") - except StreamingQueryException as e: - # This is expected - self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") - finally: - exception = sq.exception() - sq.stop() - self.assertIsInstance(exception, StreamingQueryException) - self._assert_exception_tree_contains_msg(exception, "ZeroDivisionError") + with self.sql_conf({"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": True}): + sdf = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + sq = sdf.writeStream.format("memory").queryName("query_explain").start() + try: + 
sq.processAllAvailable() + self.assertEqual(sq.exception(), None) + finally: + sq.stop() + + from pyspark.sql.functions import col, udf + from pyspark.errors import StreamingQueryException + + bad_udf = udf(lambda x: 1 / 0) + sq = ( + sdf.select(bad_udf(col("value"))) + .writeStream.format("memory") + .queryName("this_query") + .start() + ) + try: + # Process some data to fail the query + sq.processAllAvailable() + self.fail("bad udf should fail the query") + except StreamingQueryException as e: + # This is expected + self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") + finally: + exception = sq.exception() + sq.stop() + self.assertIsInstance(exception, StreamingQueryException) + self._assert_exception_tree_contains_msg(exception, "ZeroDivisionError") def _assert_exception_tree_contains_msg(self, exception, msg): if isinstance(exception, SparkConnectException): diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 87d0dae00d8bd..05c1ec71675c2 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -251,7 +251,23 @@ def test_listener_events(self): progress_event = None terminated_event = None - class TestListener(StreamingQueryListener): + # V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`, + # `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5. + class TestListenerV1(StreamingQueryListener): + def onQueryStarted(self, event): + nonlocal start_event + start_event = event + + def onQueryProgress(self, event): + nonlocal progress_event + progress_event = event + + def onQueryTerminated(self, event): + nonlocal terminated_event + terminated_event = event + + # V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+. 
+ class TestListenerV2(StreamingQueryListener): def onQueryStarted(self, event): nonlocal start_event start_event = event @@ -267,48 +283,71 @@ def onQueryTerminated(self, event): nonlocal terminated_event terminated_event = event - test_listener = TestListener() + def verify(test_listener): + nonlocal start_event + nonlocal progress_event + nonlocal terminated_event - try: - self.spark.streams.addListener(test_listener) + start_event = None + progress_event = None + terminated_event = None - df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + try: + self.spark.streams.addListener(test_listener) - # check successful stateful query - df_stateful = df.groupBy().count() # make query stateful - q = ( - df_stateful.writeStream.format("noop") - .queryName("test") - .outputMode("complete") - .start() - ) - self.assertTrue(q.isActive) - time.sleep(10) - q.stop() + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - # Make sure all events are empty - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + # check successful stateful query + df_stateful = df.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) + self.assertTrue(q.isActive) + time.sleep(10) + q.stop() - self.check_start_event(start_event) - self.check_progress_event(progress_event) - self.check_terminated_event(terminated_event) + # Make sure all events are empty + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - # Check query terminated with exception - from pyspark.sql.functions import col, udf + self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) - bad_udf = udf(lambda x: 1 / 0) - q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() - time.sleep(5) - q.stop() - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - self.check_terminated_event(terminated_event, "ZeroDivisionError") + # Check query terminated with exception + from pyspark.sql.functions import col, udf - finally: - self.spark.streams.removeListener(test_listener) + bad_udf = udf(lambda x: 1 / 0) + q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() + time.sleep(5) + q.stop() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + self.check_terminated_event(terminated_event, "ZeroDivisionError") + + finally: + self.spark.streams.removeListener(test_listener) + + verify(TestListenerV1()) + verify(TestListenerV2()) def test_remove_listener(self): # SPARK-38804: Test StreamingQueryManager.removeListener - class TestListener(StreamingQueryListener): + # V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`, + # `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5. + class TestListenerV1(StreamingQueryListener): + def onQueryStarted(self, event): + pass + + def onQueryProgress(self, event): + pass + + def onQueryTerminated(self, event): + pass + + # V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+. 
+ class TestListenerV2(StreamingQueryListener): def onQueryStarted(self, event): pass @@ -321,13 +360,15 @@ def onQueryIdle(self, event): def onQueryTerminated(self, event): pass - test_listener = TestListener() + def verify(test_listener): + num_listeners = len(self.spark.streams._jsqm.listListeners()) + self.spark.streams.addListener(test_listener) + self.assertEqual(num_listeners + 1, len(self.spark.streams._jsqm.listListeners())) + self.spark.streams.removeListener(test_listener) + self.assertEqual(num_listeners, len(self.spark.streams._jsqm.listListeners())) - num_listeners = len(self.spark.streams._jsqm.listListeners()) - self.spark.streams.addListener(test_listener) - self.assertEqual(num_listeners + 1, len(self.spark.streams._jsqm.listListeners())) - self.spark.streams.removeListener(test_listener) - self.assertEqual(num_listeners, len(self.spark.streams._jsqm.listListeners())) + verify(TestListenerV1()) + verify(TestListenerV2()) def test_query_started_event_fromJson(self): start_event = """ diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 73b6067373b07..a333318b777e9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -18,12 +18,14 @@ import datetime import os import threading +import calendar import time import unittest import warnings from distutils.version import LooseVersion from typing import cast from collections import namedtuple +import sys from pyspark import SparkContext, SparkConf from pyspark.sql import Row, SparkSession @@ -1090,6 +1092,34 @@ def check_createDataFrame_nested_timestamp(self, arrow_enabled): self.assertEqual(df.first(), expected) + @unittest.skipIf(sys.version_info < (3, 9), "zoneinfo is available from Python 3.9+") + def test_toPandas_timestmap_tzinfo(self): + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_toPandas_timestmap_tzinfo(arrow_enabled) + + def check_toPandas_timestmap_tzinfo(self, arrow_enabled): + # SPARK-47202: Test timestamp with tzinfo in toPandas and createDataFrame + from zoneinfo import ZoneInfo + + ts_tzinfo = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("America/Los_Angeles")) + data = pd.DataFrame({"a": [ts_tzinfo]}) + df = self.spark.createDataFrame(data) + + with self.sql_conf( + { + "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled, + } + ): + pdf = df.toPandas() + + expected = pd.DataFrame( + # Spark unsets tzinfo and converts them to localtimes. + {"a": [datetime.datetime.fromtimestamp(calendar.timegm(ts_tzinfo.utctimetuple()))]} + ) + + assert_frame_equal(pdf, expected) + def test_toPandas_nested_timestamp(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -1238,6 +1268,16 @@ class MyInheritedTuple(MyTuple): df = self.spark.createDataFrame([MyInheritedTuple(1, 2, MyInheritedTuple(1, 2, 3))]) self.assertEqual(df.first(), Row(a=1, b=2, c=Row(a=1, b=2, c=3))) + def test_negative_and_zero_batch_size(self): + # SPARK-47068: Negative and zero value should work as unlimited batch size. 
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 0}): + pdf = pd.DataFrame({"a": [123]}) + assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas()) + + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": -1}): + pdf = pd.DataFrame({"a": [123]}) + assert_frame_equal(pdf, self.spark.createDataFrame(pdf).toPandas()) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py index 15367743585e3..176286a809d45 100644 --- a/python/pyspark/sql/tests/test_arrow_map.py +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -104,14 +104,14 @@ def bad_iter_elem(_): with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be iterator " - "of pyarrow.RecordBatch, but is int.", + "of pyarrow.RecordBatch, but is int", ): (self.spark.range(10, numPartitions=3).mapInArrow(not_iter, "a int").count()) with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be iterator " - "of pyarrow.RecordBatch, but is iterator of int.", + "of pyarrow.RecordBatch, but is iterator of int", ): (self.spark.range(10, numPartitions=3).mapInArrow(bad_iter_elem, "a int").count()) diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index cafffdc9ae8b5..b72172a402bfc 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -486,7 +486,7 @@ def test_refresh_table(self): self.assertEqual(spark.table("my_tab").count(), 0) -class CatalogTests(ReusedSQLTestCase): +class CatalogTests(CatalogTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 33049233dee98..2f80a56e96002 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -106,6 +106,43 @@ def test_drop(self): self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"]) self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"]) + def test_drop_join(self): + left_df = self.spark.createDataFrame( + [(1, "a"), (2, "b"), (3, "c")], + ["join_key", "value1"], + ) + right_df = self.spark.createDataFrame( + [(1, "aa"), (2, "bb"), (4, "dd")], + ["join_key", "value2"], + ) + joined_df = left_df.join( + right_df, + on=left_df["join_key"] == right_df["join_key"], + how="left", + ) + + dropped_1 = joined_df.drop(left_df["join_key"]) + self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"]) + self.assertEqual( + dropped_1.sort("value1").collect(), + [ + Row(value1="a", join_key=1, value2="aa"), + Row(value1="b", join_key=2, value2="bb"), + Row(value1="c", join_key=None, value2=None), + ], + ) + + dropped_2 = joined_df.drop(right_df["join_key"]) + self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"]) + self.assertEqual( + dropped_2.sort("value1").collect(), + [ + Row(join_key=1, value1="a", value2="aa"), + Row(join_key=2, value1="b", value2="bb"), + Row(join_key=3, value1="c", value2=None), + ], + ) + def test_with_columns_renamed(self): df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) @@ -1008,6 +1045,11 @@ def test_sample(self): IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) + def test_sample_with_random_seed(self): + df = self.spark.range(10000).sample(0.1) + cnts = [df.count() for i in range(10)] + self.assertEqual(1, 
len(set(cnts))) + def test_toDF_with_string(self): df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) data = [("John", 30), ("Alice", 25), ("Bob", 28)] @@ -1508,6 +1550,9 @@ def test_create_dataframe_from_pandas_with_day_time_interval(self): df = self.spark.createDataFrame(pd.DataFrame({"a": [timedelta(microseconds=123)]})) self.assertEqual(df.toPandas().a.iloc[0], timedelta(microseconds=123)) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Newline difference from the server" + ) def test_repr_behaviors(self): import re diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index 6418983b06a44..c920fa75f4b29 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os +import unittest import shutil import tempfile import uuid @@ -146,6 +147,9 @@ def test_csv_sampling_ratio(self): schema = self.spark.read.option("inferSchema", True).csv(rdd, samplingRatio=0.5).schema self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)])) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_checking_csv_header(self): path = tempfile.mkdtemp() shutil.rmtree(path) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 6bcef51732f86..7911a82c61fcc 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -16,11 +16,12 @@ # import os +import unittest import shutil import tempfile from pyspark.errors import AnalysisException -from pyspark.sql.functions import col +from pyspark.sql.functions import col, lit from pyspark.sql.readwriter import DataFrameWriterV2 from pyspark.sql.types import StructType, StructField, StringType from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -181,6 +182,27 @@ def test_insert_into(self): df.write.mode("overwrite").insertInto("test_table", False) self.assertEqual(6, self.spark.sql("select * from test_table").count()) + def test_cached_table(self): + with self.table("test_cached_table_1"): + self.spark.range(10).withColumn( + "value_1", + lit(1), + ).write.saveAsTable("test_cached_table_1") + + with self.table("test_cached_table_2"): + self.spark.range(10).withColumnRenamed("id", "index").withColumn( + "value_2", lit(2) + ).write.saveAsTable("test_cached_table_2") + + df1 = self.spark.read.table("test_cached_table_1") + df2 = self.spark.read.table("test_cached_table_2") + df3 = self.spark.read.table("test_cached_table_1") + + join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2) + join2 = df3.join(join1, how="left", on=join1.index == df3.id) + + self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"]) + class ReadwriterV2TestsMixin: def test_api(self): @@ -224,6 +246,9 @@ def test_create(self): df.writeTo("test_table").using("parquet").create() self.assertEqual(100, self.spark.sql("select * from test_table").count()) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Known behavior change in 4.0" + ) def test_create_without_provider(self): df = self.df with self.assertRaisesRegex( @@ -231,6 +256,11 @@ def test_create_without_provider(self): ): df.writeTo("test_table").create() + def test_table_overwrite(self): + df = self.df + with 
self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"): + df.writeTo("test_table").overwrite(lit(True)) + class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 90ecfd657765d..504d945dc02a0 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -753,6 +753,9 @@ def test_cast_to_string_with_udt(self): result = df.select(F.col("point").cast("string"), F.col("pypoint").cast("string")).head() self.assertEqual(result, Row(point="(1.0, 2.0)", pypoint="[3.0, 4.0]")) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "SPARK-49787: Supported since Spark 4.0.0" + ) def test_cast_to_udt_with_udt(self): row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) @@ -812,6 +815,9 @@ def test_struct_type(self): self.assertRaises(IndexError, lambda: struct1[9]) self.assertRaises(TypeError, lambda: struct1[9.9]) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_parse_datatype_string(self): from pyspark.sql.types import _all_atomic_types, _parse_datatype_string @@ -1272,6 +1278,13 @@ def test_yearmonth_interval_type(self): schema3 = self.spark.sql("SELECT INTERVAL '8' MONTH AS interval").schema self.assertEqual(schema3.fields[0].dataType, YearMonthIntervalType(1, 1)) + def test_infer_array_element_type_with_struct(self): + # SPARK-48248: Nested array to respect legacy conf of inferArrayTypeFromFirstElement + with self.sql_conf( + {"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled": True} + ): + self.assertEqual([[1, None]], self.spark.createDataFrame([[[[1, "a"]]]]).first()[0]) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py index 136f423d0a35c..019e502ec67cf 100644 --- a/python/pyspark/sql/tests/test_udf_profiler.py +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -19,11 +19,19 @@ import unittest import os import sys +import warnings from io import StringIO +from typing import Iterator from pyspark import SparkConf from pyspark.sql import SparkSession -from pyspark.sql.functions import udf +from pyspark.sql.functions import udf, pandas_udf +from pyspark.testing.sqlutils import ( + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) from pyspark.profiler import UDFBasicProfiler @@ -101,6 +109,49 @@ def add2(x): df = self.spark.range(10) df.select(add1("id"), add2("id"), add1("id")).collect() + # Unsupported + def exec_pandas_udf_iter_to_iter(self): + import pandas as pd + + @pandas_udf("int") + def iter_to_iter(batch_ser: Iterator[pd.Series]) -> Iterator[pd.Series]: + for ser in batch_ser: + yield ser + 1 + + self.spark.range(10).select(iter_to_iter("id")).collect() + + # Unsupported + def exec_map(self): + import pandas as pd + + def map(pdfs: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: + for pdf in pdfs: + yield pdf[pdf.id == 1] + + df = self.spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v")) + df.mapInPandas(map, schema=df.schema).collect() + + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) # type: ignore + def test_unsupported(self): + with 
warnings.catch_warnings(record=True) as warns: + warnings.simplefilter("always") + self.exec_pandas_udf_iter_to_iter() + user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0]) + ) + + with warnings.catch_warnings(record=True) as warns: + warnings.simplefilter("always") + self.exec_map() + user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0]) + ) + if __name__ == "__main__": from pyspark.sql.tests.test_udf_profiler import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index e1b7f298d0a8b..dc25fc116a8bd 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -1611,6 +1611,18 @@ def test_list_row_unequal_schema(self): message_parameters={"error_msg": error_msg}, ) + def test_assert_data_frame_equal_not_support_streaming(self): + df1 = self.spark.readStream.format("rate").load() + df2 = self.spark.readStream.format("rate").load() + exception_thrown = False + try: + assertDataFrameEqual(df1, df2) + except PySparkAssertionError as e: + self.assertEqual(e.getErrorClass(), "UNSUPPORTED_OPERATION") + exception_thrown = True + + self.assertTrue(exception_thrown) + class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin): def test_capture_analysis_exception(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 24964c56e2e89..a2a8796957623 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1606,13 +1606,27 @@ def _infer_type( if len(obj) > 0: if infer_array_from_first_element: return ArrayType( - _infer_type(obj[0], infer_dict_as_struct, prefer_timestamp_ntz), True + _infer_type( + obj[0], + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + True, ) else: return ArrayType( reduce( _merge_type, - (_infer_type(v, infer_dict_as_struct, prefer_timestamp_ntz) for v in obj), + ( + _infer_type( + v, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ) + for v in obj + ), ), True, ) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 7d7784dd5226d..bdd3aba502b89 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -28,7 +28,6 @@ from py4j.java_gateway import JavaObject from pyspark import SparkContext -from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq from pyspark.sql.types import ( @@ -338,24 +337,23 @@ def _create_judf(self, func: Callable[..., Any]) -> JavaObject: def __call__(self, *cols: "ColumnOrName") -> Column: sc = get_active_spark_context() - profiler: Optional[Profiler] = None - memory_profiler: Optional[Profiler] = None - if sc.profiler_collector: - profiler_enabled = sc._conf.get("spark.python.profile", "false") == "true" - memory_profiler_enabled = sc._conf.get("spark.python.profile.memory", "false") == "true" + profiler_enabled = sc._conf.get("spark.python.profile", "false") == "true" + memory_profiler_enabled = sc._conf.get("spark.python.profile.memory", "false") == "true" + if profiler_enabled or memory_profiler_enabled: # Disable profiling Pandas UDFs 
with iterators as input/output. - if profiler_enabled or memory_profiler_enabled: - if self.evalType in [ - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - PythonEvalType.SQL_MAP_ARROW_ITER_UDF, - ]: - profiler_enabled = memory_profiler_enabled = False - warnings.warn( - "Profiling UDFs with iterators input/output is not supported.", - UserWarning, - ) + if self.evalType in [ + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + ]: + warnings.warn( + "Profiling UDFs with iterators input/output is not supported.", + UserWarning, + ) + judf = self._judf + jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) + return Column(jPythonUDF) # Disallow enabling two profilers at the same time. if profiler_enabled and memory_profiler_enabled: diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index ba81c7836728e..a063f27c9ea22 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -178,7 +178,7 @@ def conf(cls): @classmethod def master(cls): - return "local[4]" + return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]") @classmethod def setUpClass(cls): diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index c80ffb7ee53cb..2463289d59f71 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -365,6 +365,9 @@ def assertPandasOnSparkEqual( .. versionadded:: 3.5.0 + .. deprecated:: 3.5.1 + `assertPandasOnSparkEqual` will be removed in Spark 4.0.0. + Parameters ---------- actual: pandas-on-Spark DataFrame, Series, or Index @@ -417,6 +420,10 @@ def assertPandasOnSparkEqual( >>> s2 = ps.Index([212.3, 100.0001]) >>> assertPandasOnSparkEqual(s1, s2, almost=True) # pass, ps.Index obj are almost equal """ + warnings.warn( + "`assertPandasOnSparkEqual` will be removed in Spark 4.0.0. ", + FutureWarning, + ) if actual is None and expected is None: return True elif actual is None or expected is None: diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 7dd723634e2f9..aa927ed1c4c2b 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -587,11 +587,21 @@ def assert_rows_equal(rows1: List[Row], rows2: List[Row]): assertSchemaEqual(actual.schema, expected.schema) if not isinstance(actual, list): + if actual.isStreaming: + raise PySparkAssertionError( + error_class="UNSUPPORTED_OPERATION", + message_parameters={"operation": "assertDataFrameEqual on streaming DataFrame"}, + ) actual_list = actual.collect() else: actual_list = actual if not isinstance(expected, list): + if expected.isStreaming: + raise PySparkAssertionError( + error_class="UNSUPPORTED_OPERATION", + message_parameters={"operation": "assertDataFrameEqual on streaming DataFrame"}, + ) expected_list = expected.collect() else: expected_list = expected diff --git a/python/pyspark/version.py b/python/pyspark/version.py index daccb365340b7..002d06e28ea15 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
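The connectutils change above resolves the test master from SPARK_CONNECT_TESTING_REMOTE, so the same suites can be pointed at an external Spark Connect server instead of local[4]. A small sketch of that lookup under the same assumption; the helper name and endpoint below are illustrative only:

import os


def testing_remote(default="local[4]"):
    # Falls back to a local master when no endpoint is configured; setting
    # SPARK_CONNECT_TESTING_REMOTE=sc://localhost:15002 would instead point the
    # suite at an already-running Spark Connect server.
    return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", default)


if __name__ == "__main__":
    print(testing_remote())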
-__version__: str = "3.5.0" +__version__: str = "3.5.4.dev0" diff --git a/python/run-tests.py b/python/run-tests.py index 19e39c822cbb4..ca8ddb5ff8635 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -60,15 +60,17 @@ def get_valid_filename(s): FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() -# Find out where the assembly jars are located. -# TODO: revisit for Scala 2.13 -for scala in ["2.12", "2.13"]: - build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) - if os.path.isdir(build_dir): - SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") - break -else: - raise RuntimeError("Cannot find assembly build directory, please build Spark first.") +SPARK_DIST_CLASSPATH = "" +if "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ: + # Find out where the assembly jars are located. + # TODO: revisit for Scala 2.13 + for scala in ["2.12", "2.13"]: + build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) + if os.path.isdir(build_dir): + SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") + break + else: + raise RuntimeError("Cannot find assembly build directory, please build Spark first.") def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_output): @@ -98,6 +100,11 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_ 'PYARROW_IGNORE_TIMEZONE': '1', }) + if "SPARK_CONNECT_TESTING_REMOTE" in os.environ: + env.update({"SPARK_CONNECT_TESTING_REMOTE": os.environ["SPARK_CONNECT_TESTING_REMOTE"]}) + if "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ: + env.update({"SPARK_SKIP_JVM_REQUIRED_TESTS": os.environ["SPARK_SKIP_CONNECT_COMPAT_TESTS"]}) + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is # recognized by the tempfile module to override the default system temp directory. tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) @@ -147,8 +154,8 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_ # this code is invoked from a thread other than the main thread. os._exit(1) duration = time.time() - start_time - # Exit on the first failure. - if retcode != 0: + # Exit on the first failure but exclude the code 5 for no test ran, see SPARK-46801. + if retcode != 0 and retcode != 5: try: with FAILURE_REPORTING_LOCK: with open(LOG_FILE, 'ab') as log_file: diff --git a/python/setup.py b/python/setup.py index b8e4c9a40e046..ddd961d0412b5 100755 --- a/python/setup.py +++ b/python/setup.py @@ -249,6 +249,7 @@ def run(self): "pyspark.sql.connect.avro", "pyspark.sql.connect.client", "pyspark.sql.connect.proto", + "pyspark.sql.connect.protobuf", "pyspark.sql.connect.streaming", "pyspark.sql.pandas", "pyspark.sql.protobuf", @@ -306,17 +307,17 @@ def run(self): # if you're updating the versions or dependencies. 
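The run-tests.py change above stops treating exit code 5 as a test failure: code 5 conventionally signals that no tests were collected (pytest uses it for exactly that), which now happens legitimately when whole suites are skipped via SPARK_SKIP_CONNECT_COMPAT_TESTS (see SPARK-46801). A hedged sketch of that check, with an illustrative helper name:

def is_real_failure(retcode):
    # 0 means success and 5 means no tests ran; anything else is a genuine failure.
    return retcode not in (0, 5)


assert not is_real_failure(0)
assert not is_real_failure(5)
assert is_real_failure(1)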
install_requires=["py4j==0.10.9.7"], extras_require={ - "ml": ["numpy>=1.15"], - "mllib": ["numpy>=1.15"], + "ml": ["numpy>=1.15,<2"], + "mllib": ["numpy>=1.15,<2"], "sql": [ "pandas>=%s" % _minimum_pandas_version, "pyarrow>=%s" % _minimum_pyarrow_version, - "numpy>=1.15", + "numpy>=1.15,<2", ], "pandas_on_spark": [ "pandas>=%s" % _minimum_pandas_version, "pyarrow>=%s" % _minimum_pyarrow_version, - "numpy>=1.15", + "numpy>=1.15,<2", ], "connect": [ "pandas>=%s" % _minimum_pandas_version, @@ -324,7 +325,7 @@ def run(self): "grpcio>=%s" % _minimum_grpc_version, "grpcio-status>=%s" % _minimum_grpc_version, "googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version, - "numpy>=1.15", + "numpy>=1.15,<2", ], }, python_requires=">=3.8", diff --git a/repl/pom.xml b/repl/pom.xml index 6214dc2e18555..5ef505bbc48e5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala index 39544beec4154..067f08cb67528 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.repl import java.io._ import java.nio.charset.StandardCharsets -import java.sql.Timestamp -import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.concurrent.Promise @@ -70,10 +68,9 @@ class SparkShellSuite extends SparkFunSuite { val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we - // add a timestamp to provide more diagnosis information. - val newLine = s"${new Timestamp(new Date().getTime)} - $source> $line" - log.info(newLine) + val newLine = s"$source> $line" + + logInfo(newLine) buffer += newLine if (line.startsWith("Spark context available") && line.contains("app id")) { @@ -82,7 +79,7 @@ class SparkShellSuite extends SparkFunSuite { // If we haven't found all expected answers and another expected answer comes up... if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) { - log.info(s"$source> found expected output line $next: '${expectedAnswers(next)}'") + logInfo(s"$source> found expected output line $next: '${expectedAnswers(next)}'") next += 1 // If all expected answers have been found... 
if (next == expectedAnswers.size) { diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 61ba640291902..872b5b8380e6a 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 0c54191fb10d5..dd8c59204b5ee 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -63,6 +63,15 @@ private[spark] object Config extends Logging { .booleanConf .createWithDefault(true) + val KUBERNETES_USE_LEGACY_PVC_ACCESS_MODE = + ConfigBuilder("spark.kubernetes.legacy.useReadWriteOnceAccessMode") + .internal() + .doc("If true, use ReadWriteOnce instead of ReadWriteOncePod as persistence volume " + + "access mode.") + .version("3.4.3") + .booleanConf + .createWithDefault(true) + val KUBERNETES_DRIVER_SERVICE_IP_FAMILY_POLICY = ConfigBuilder("spark.kubernetes.driver.service.ipFamilyPolicy") .doc("K8s IP Family Policy for Driver Service") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 78dd6ec21ed34..64f3491b861ee 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import io.fabric8.kubernetes.api.model._ import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config.KUBERNETES_USE_LEGACY_PVC_ACCESS_MODE import org.apache.spark.deploy.k8s.Constants.{ENV_EXECUTOR_ID, SPARK_APP_ID_LABEL} private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) @@ -29,6 +30,11 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) import MountVolumesFeatureStep._ val additionalResources = ArrayBuffer.empty[HasMetadata] + val accessMode = if (conf.get(KUBERNETES_USE_LEGACY_PVC_ACCESS_MODE)) { + "ReadWriteOnce" + } else { + PVC_ACCESS_MODE + } override def configurePod(pod: SparkPod): SparkPod = { val (volumeMounts, volumes) = constructVolumes(conf.volumes).unzip @@ -89,7 +95,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) - .withAccessModes(PVC_ACCESS_MODE) + .withAccessModes(accessMode) .withResources(new ResourceRequirementsBuilder() .withRequests(Map("storage" -> new Quantity(size.get)).asJava).build()) .endSpec() @@ -126,5 +132,5 @@ private[spark] object MountVolumesFeatureStep { val PVC_ON_DEMAND = "OnDemand" val PVC = "PersistentVolumeClaim" val PVC_POSTFIX = "-pvc" - val PVC_ACCESS_MODE = "ReadWriteOnce" + val PVC_ACCESS_MODE = "ReadWriteOncePod" } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 
8151ee34b19ff..73d900f3b9e63 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -472,7 +472,7 @@ class ExecutorPodsAllocator( val reusablePVCs = createdPVCs .filterNot(pvc => pvcsInUse.contains(pvc.getMetadata.getName)) .filter(pvc => now - Instant.parse(pvc.getMetadata.getCreationTimestamp).toEpochMilli - > podAllocationDelay) + > podCreationTimeout) logInfo(s"Found ${reusablePVCs.size} reusable PVCs from ${createdPVCs.size} PVCs") reusablePVCs } catch { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index 5d91070bcab20..5b645e6e97508 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -61,6 +61,9 @@ private[spark] class ExecutorPodsLifecycleManager( private val namespace = conf.get(KUBERNETES_NAMESPACE) + private val sparkContainerName = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME) + .getOrElse(DEFAULT_EXECUTOR_CONTAINER_NAME) + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) snapshotsStore.addSubscriber(eventProcessingInterval) { @@ -240,7 +243,8 @@ private[spark] class ExecutorPodsLifecycleManager( private def findExitCode(podState: FinalPodState): Int = { podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => - containerStatus.getState.getTerminated != null + containerStatus.getName == sparkContainerName && + containerStatus.getState.getTerminated != null }.map { terminatedContainer => terminatedContainer.getState.getTerminated.getExitCode.toInt }.getOrElse(UNKNOWN_EXIT_CODE) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala index de9da0de7da2f..c19955424c052 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.resource.ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID object ExecutorLifecycleTestUtils { val TEST_SPARK_APP_ID = "spark-app-id" + val TEST_SPARK_EXECUTOR_CONTAINER_NAME = "spark-executor" def failedExecutorWithoutDeletion( executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { @@ -37,7 +38,7 @@ object ExecutorLifecycleTestUtils { .withPhase("failed") .withStartTime(Instant.now.toString) .addNewContainerStatus() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .withNewState() .withNewTerminated() @@ -49,6 +50,38 @@ object ExecutorLifecycleTestUtils { .addNewContainerStatus() .withName("spark-executor-sidecar") .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + 
.withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def failedExecutorWithSidecarStatusListedFirst( + executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId, rpId)) + .editOrNewStatus() + .withPhase("failed") + .withStartTime(Instant.now.toString) + .addNewContainerStatus() // sidecar status listed before executor's container status + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) + .withImage("k8s-spark") .withNewState() .withNewTerminated() .withMessage("Failed") @@ -200,7 +233,7 @@ object ExecutorLifecycleTestUtils { .endSpec() .build() val container = new ContainerBuilder() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .build() SparkPod(pod, container) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 350a09f0218ba..f202499c849e1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -768,7 +768,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { val pvc = persistentVolumeClaim("pvc-0", "gp2", "200Gi") pvc.getMetadata - .setCreationTimestamp(Instant.now().minus(podAllocationDelay + 1, MILLIS).toString) + .setCreationTimestamp(Instant.now().minus(podCreationTimeout + 1, MILLIS).toString) when(persistentVolumeClaimList.getItems).thenReturn(Seq(pvc).asJava) when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr), meq(kubernetesClient), any(classOf[ResourceProfile]))) @@ -842,15 +842,17 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { val getReusablePVCs = PrivateMethod[mutable.Buffer[PersistentVolumeClaim]](Symbol("getReusablePVCs")) - val pvc1 = persistentVolumeClaim("pvc-0", "gp2", "200Gi") - val pvc2 = persistentVolumeClaim("pvc-1", "gp2", "200Gi") + val pvc1 = persistentVolumeClaim("pvc-1", "gp2", "200Gi") + val pvc2 = persistentVolumeClaim("pvc-2", "gp2", "200Gi") val now = Instant.now() - pvc1.getMetadata.setCreationTimestamp(now.minus(2 * podAllocationDelay, MILLIS).toString) + pvc1.getMetadata.setCreationTimestamp(now.minus(podCreationTimeout + 1, MILLIS).toString) pvc2.getMetadata.setCreationTimestamp(now.toString) when(persistentVolumeClaimList.getItems).thenReturn(Seq(pvc1, pvc2).asJava) - podsAllocatorUnderTest invokePrivate getReusablePVCs("appId", Seq("pvc-1")) + val reusablePVCs = podsAllocatorUnderTest invokePrivate getReusablePVCs("appId", Seq.empty) + assert(reusablePVCs.size == 1) + assert(reusablePVCs.head.getMetadata.getName == "pvc-1") } test("SPARK-41410: Support waitToReusePersistentVolumeClaims") { diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index c34938caeca70..972cd79088dfc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config +import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ import org.apache.spark.deploy.k8s.KubernetesUtils._ @@ -60,6 +61,8 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte before { MockitoAnnotations.openMocks(this).close() + val sparkConf = new SparkConf() + .set(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME, TEST_SPARK_EXECUTOR_CONTAINER_NAME) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() namedExecutorPods = mutable.Map.empty[String, PodResource] when(schedulerBackend.getExecutorsWithRegistrationTs()).thenReturn(Map.empty[String, Long]) @@ -67,7 +70,7 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte when(podOperations.inNamespace(anyString())).thenReturn(podsWithNamespace) when(podsWithNamespace.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) eventHandlerUnderTest = new ExecutorPodsLifecycleManager( - new SparkConf(), + sparkConf, kubernetesClient, snapshotsStore) eventHandlerUnderTest.start(schedulerBackend) @@ -136,6 +139,15 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte .edit(any[UnaryOperator[Pod]]()) } + test("SPARK-49804: Use the exit code of executor container always") { + val failedPod = failedExecutorWithSidecarStatusListedFirst(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod, 1) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + private def exitReasonMessage(execId: Int, failedPod: Pod, exitCode: Int): String = { val reason = Option(failedPod.getStatus.getReason) val message = Option(failedPod.getStatus.getMessage) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 88304c87a79c3..c30823de01360 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -ARG java_image_tag=17-jre +ARG java_image_tag=17-jammy FROM eclipse-temurin:${java_image_tag} diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index c7e543137385b..b72a3daea3c38 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index f52af87a745ca..54ef1f6cee30d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -45,8 +45,8 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SparkConfPropagateSuite with SecretsTestsSuite with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite - with PVTestsSuite with DepsTestsSuite with DecommissionSuite with RTestsSuite with Logging - with Eventually with Matchers { + with VolumeSuite with PVTestsSuite with DepsTestsSuite with DecommissionSuite with RTestsSuite + with Logging with Eventually with Matchers { import KubernetesSuite._ diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 4aba11bdb9d8f..4ebf44ce9a4bc 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -75,6 +75,8 @@ private[spark] class KubernetesTestComponents(val kubernetesClient: KubernetesCl .set(UI_ENABLED.key, "true") .set("spark.kubernetes.submission.waitAppCompletion", "false") .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) + .set("spark.kubernetes.driver.request.cores", "0.2") + .set("spark.kubernetes.executor.request.cores", "0.2") } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PVTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PVTestsSuite.scala index 1d373f3f8066e..a699ef674cdcd 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PVTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PVTestsSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Milliseconds, Span} import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend private[spark] trait PVTestsSuite { 
k8sSuite: KubernetesSuite => import PVTestsSuite._ @@ -54,6 +55,7 @@ private[spark] trait PVTestsSuite { k8sSuite: KubernetesSuite => setupLocalStorageClass() + val hostname = if (testBackend == MinikubeTestBackend) "minikube" else "docker-desktop" val pvBuilder = new PersistentVolumeBuilder() .withKind("PersistentVolume") .withApiVersion("v1") @@ -72,7 +74,7 @@ private[spark] trait PVTestsSuite { k8sSuite: KubernetesSuite => .withMatchExpressions(new NodeSelectorRequirementBuilder() .withKey("kubernetes.io/hostname") .withOperator("In") - .withValues("minikube", "m01", "docker-for-desktop", "docker-desktop") + .withValues(hostname) .build()).build()) .endRequired() .endNodeAffinity() @@ -166,7 +168,7 @@ private[spark] trait PVTestsSuite { k8sSuite: KubernetesSuite => } } - test("PVs with local hostpath and storageClass on statefulsets", k8sTestTag, MinikubeTag) { + ignore("PVs with local hostpath and storageClass on statefulsets", k8sTestTag, MinikubeTag) { sparkAppConf .set(s"spark.kubernetes.driver.volumes.persistentVolumeClaim.data.mount.path", CONTAINER_MOUNT_PATH) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala index 06d6f7dc100f3..e7143e32db61e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala @@ -496,8 +496,8 @@ private[spark] object VolcanoTestsSuite extends SparkFunSuite { val DRIVER_PG_TEMPLATE_MEMORY_3G = new File( getClass.getResource("/volcano/driver-podgroup-template-memory-3g.yml").getFile ).getAbsolutePath - val DRIVER_REQUEST_CORES = sys.props.get(CONFIG_DRIVER_REQUEST_CORES).getOrElse("1") - val EXECUTOR_REQUEST_CORES = sys.props.get(CONFIG_EXECUTOR_REQUEST_CORES).getOrElse("1") + val DRIVER_REQUEST_CORES = sys.props.get(CONFIG_DRIVER_REQUEST_CORES).getOrElse("0.2") + val EXECUTOR_REQUEST_CORES = sys.props.get(CONFIG_EXECUTOR_REQUEST_CORES).getOrElse("0.2") val VOLCANO_MAX_JOB_NUM = sys.props.get(CONFIG_KEY_VOLCANO_MAX_JOB_NUM).getOrElse("2") val TEMP_DIR = "/tmp/" } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolumeSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolumeSuite.scala new file mode 100644 index 0000000000000..c57e4b4578d6c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolumeSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import scala.jdk.CollectionConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.scalatest.concurrent.PatienceConfiguration +import org.scalatest.time.{Seconds, Span} + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend + +private[spark] trait VolumeSuite { k8sSuite: KubernetesSuite => + val IGNORE = Some((Some(PatienceConfiguration.Interval(Span(0, Seconds))), None)) + + private def checkDisk(pod: Pod, path: String, expected: String) = { + eventually(PatienceConfiguration.Timeout(Span(10, Seconds)), INTERVAL) { + implicit val podName: String = pod.getMetadata.getName + implicit val components: KubernetesTestComponents = kubernetesTestComponents + assert(Utils.executeCommand("df", path).contains(expected)) + } + } + + test("A driver-only Spark job with a tmpfs-backed localDir volume", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.master", "local[10]") + .set("spark.kubernetes.local.dirs.tmpfs", "true") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("local[10]", "Pi is roughly 3"), + Seq(), + Array.empty[String], + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + val path = driverPod.getSpec.getContainers.get(0).getEnv.asScala + .filter(_.getName == "SPARK_LOCAL_DIRS").map(_.getValue).head + checkDisk(driverPod, path, "tmpfs") + }, + _ => (), + isJVM = true, + executorPatience = IGNORE) + } + + test("A driver-only Spark job with a tmpfs-backed emptyDir data volume", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.master", "local[10]") + .set("spark.kubernetes.driver.volumes.emptyDir.data.mount.path", "/data") + .set("spark.kubernetes.driver.volumes.emptyDir.data.options.medium", "Memory") + .set("spark.kubernetes.driver.volumes.emptyDir.data.options.sizeLimit", "1G") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("local[10]", "Pi is roughly 3"), + Seq(), + Array.empty[String], + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkDisk(driverPod, "/data", "tmpfs") + }, + _ => (), + isJVM = true, + executorPatience = IGNORE) + } + + test("A driver-only Spark job with a disk-backed emptyDir volume", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.master", "local[10]") + .set("spark.kubernetes.driver.volumes.emptyDir.data.mount.path", "/data") + .set("spark.kubernetes.driver.volumes.emptyDir.data.mount.sizeLimit", "1G") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("local[10]", "Pi is roughly 3"), + Seq(), + Array.empty[String], + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkDisk(driverPod, "/data", "/dev/") + }, + _ => (), + isJVM = true, + executorPatience = IGNORE) + } + + test("A driver-only Spark job with an OnDemand PVC volume", k8sTestTag) { + val storageClassName = if 
(testBackend == MinikubeTestBackend) "standard" else "hostpath" + val DRIVER_PREFIX = "spark.kubernetes.driver.volumes.persistentVolumeClaim" + sparkAppConf + .set("spark.kubernetes.driver.master", "local[10]") + .set(s"$DRIVER_PREFIX.data.options.claimName", "OnDemand") + .set(s"$DRIVER_PREFIX.data.options.storageClass", storageClassName) + .set(s"$DRIVER_PREFIX.data.options.sizeLimit", "1Gi") + .set(s"$DRIVER_PREFIX.data.mount.path", "/data") + .set(s"$DRIVER_PREFIX.data.mount.readOnly", "false") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("local[10]", "Pi is roughly 3"), + Seq(), + Array.empty[String], + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkDisk(driverPod, "/data", "/dev/") + }, + _ => (), + isJVM = true, + executorPatience = IGNORE) + } + + test("A Spark job with tmpfs-backed localDir volumes", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.local.dirs.tmpfs", "true") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + Seq(), + Array.empty[String], + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + val path = driverPod.getSpec.getContainers.get(0).getEnv.asScala + .filter(_.getName == "SPARK_LOCAL_DIRS").map(_.getValue).head + checkDisk(driverPod, path, "tmpfs") + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + val path = executorPod.getSpec.getContainers.get(0).getEnv.asScala + .filter(_.getName == "SPARK_LOCAL_DIRS").map(_.getValue).head + checkDisk(executorPod, path, "tmpfs") + }, + isJVM = true) + } + + test("A Spark job with two executors with OnDemand PVC volumes", k8sTestTag) { + val storageClassName = if (testBackend == MinikubeTestBackend) "standard" else "hostpath" + val EXECUTOR_PREFIX = "spark.kubernetes.executor.volumes.persistentVolumeClaim" + sparkAppConf + .set("spark.executor.instances", "2") + .set(s"$EXECUTOR_PREFIX.data.options.claimName", "OnDemand") + .set(s"$EXECUTOR_PREFIX.data.options.storageClass", storageClassName) + .set(s"$EXECUTOR_PREFIX.data.options.sizeLimit", "1Gi") + .set(s"$EXECUTOR_PREFIX.data.mount.path", "/data") + .set(s"$EXECUTOR_PREFIX.data.mount.readOnly", "false") + runSparkApplicationAndVerifyCompletion( + containerLocalSparkDistroExamplesJar, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + Seq(), + Array.empty[String], + _ => (), + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkDisk(executorPod, "/data", "/dev/") + }, + isJVM = true) + } +} diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 1dda41e085178..31377cbda5d8e 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index fc07d304a038f..d7f3786e1050f 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml @@ -54,12 +54,12 @@ org.bouncycastle - bcprov-jdk15on + bcprov-jdk18on test org.bouncycastle - bcpkix-jdk15on + bcpkix-jdk18on test diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 
8257a08fd14e2..9c8a6dd8db069 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -883,14 +883,13 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv( + private[yarn] def setupLaunchEnv( stagingDirPath: Path, pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString - env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() env("SPARK_PREFER_IPV6") = Utils.preferIPv6.toString // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* @@ -900,6 +899,10 @@ private[spark] class Client( .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } + if (!env.contains("SPARK_USER")) { + env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() + } + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH // of the container processes too. Add all non-.py files directly to PYTHONPATH. // diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala index 0dd4e0a6c8ad9..f9aa11c4d48d6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala @@ -168,7 +168,7 @@ private object ResourceRequestHelper extends Logging { if (numResourceErrors < 2) { logWarning(s"YARN doesn't know about resource $name, your resource discovery " + s"has to handle properly discovering and isolating the resource! Error: " + - s"${e.getCause.getMessage}") + s"${e.getMessage}") numResourceErrors += 1 } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 19c06f957318b..5fccc8c9ff47c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -384,19 +384,28 @@ private[yarn] class YarnAllocator( this.numLocalityAwareTasksPerResourceProfileId = numLocalityAwareTasksPerResourceProfileId this.hostToLocalTaskCountPerResourceProfileId = hostToLocalTaskCountPerResourceProfileId - val res = resourceProfileToTotalExecs.map { case (rp, numExecs) => - createYarnResourceForResourceProfile(rp) - if (numExecs != getOrUpdateTargetNumExecutorsForRPId(rp.id)) { - logInfo(s"Driver requested a total number of $numExecs executor(s) " + - s"for resource profile id: ${rp.id}.") - targetNumExecutorsPerResourceProfileId(rp.id) = numExecs - allocatorNodeHealthTracker.setSchedulerExcludedNodes(excludedNodes) - true - } else { - false + if (resourceProfileToTotalExecs.isEmpty) { + // Set target executor number to 0 to cancel pending allocate request. 
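A standalone sketch of what the branch below does, using a plain mutable map in place of the allocator's internal state; the names here are assumptions for illustration only.

    import scala.collection.mutable

    // Hypothetical illustration: when no resource profiles are requested at all,
    // zero every per-profile target so pending container requests are cancelled
    // on the next allocation cycle, and report that the targets changed.
    def cancelAllPendingRequests(targets: mutable.Map[Int, Int]): Boolean = {
      targets.keys.foreach(rp => targets(rp) = 0)
      true
    }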
+ targetNumExecutorsPerResourceProfileId.keys.foreach { rp => + targetNumExecutorsPerResourceProfileId(rp) = 0 + } + allocatorNodeHealthTracker.setSchedulerExcludedNodes(excludedNodes) + true + } else { + val res = resourceProfileToTotalExecs.map { case (rp, numExecs) => + createYarnResourceForResourceProfile(rp) + if (numExecs != getOrUpdateTargetNumExecutorsForRPId(rp.id)) { + logInfo(s"Driver requested a total number of $numExecs executor(s) " + + s"for resource profile id: ${rp.id}.") + targetNumExecutorsPerResourceProfileId(rp.id) = numExecs + allocatorNodeHealthTracker.setSchedulerExcludedNodes(excludedNodes) + true + } else { + false + } } + res.exists(_ == true) } - res.exists(_ == true) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index b7fb409ebc359..a59a0112c78dc 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -29,6 +29,7 @@ import scala.collection.mutable.{HashMap => MutableHashMap} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords.{GetNewApplicationResponse, SubmitApplicationRequest} import org.apache.hadoop.yarn.api.records._ @@ -670,6 +671,21 @@ class ClientSuite extends SparkFunSuite assertUserClasspathUrls(cluster = true, replacementRootPath) } + test("default app master SPARK_USER") { + val sparkConf = new SparkConf() + val client = createClient(sparkConf) + val env = client.setupLaunchEnv(new Path("/staging/dir/path"), Seq()) + env("SPARK_USER") should be (UserGroupInformation.getCurrentUser().getShortUserName()) + } + + test("override app master SPARK_USER") { + val sparkConf = new SparkConf() + .set("spark.yarn.appMasterEnv.SPARK_USER", "overrideuser") + val client = createClient(sparkConf) + val env = client.setupLaunchEnv(new Path("/staging/dir/path"), Seq()) + env("SPARK_USER") should be ("overrideuser") + } + private val matching = Seq( ("files URI match test1", "file:///file1", "file:///file2"), ("files URI match test2", "file:///c:file1", "file://c:file2"), diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 3cfd5acfe2b56..28d205f03e0fa 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -31,7 +31,7 @@ # SPARK_NO_DAEMONIZE If set, will run the proposed command in the foreground. It will not output a PID file. 
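A minimal sketch of the SPARK_USER defaulting behaviour exercised by the two ClientSuite tests above, with simplified types; the helper below is hypothetical and only mirrors the ordering (user-supplied spark.yarn.appMasterEnv.* values first, then the default).

    import scala.collection.mutable

    // Hypothetical mirror of the fixed ordering: apply user-provided AM env vars
    // first, then only default SPARK_USER when the user has not overridden it.
    def buildAmEnv(userEnv: Map[String, String], currentUser: String): mutable.HashMap[String, String] = {
      val env = new mutable.HashMap[String, String]()
      userEnv.foreach { case (k, v) => env(k) = v }
      if (!env.contains("SPARK_USER")) {
        env("SPARK_USER") = currentUser
      }
      env
    }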
## -usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|status) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|decommission|status) " # if no args specified, show usage if [ $# -le 1 ]; then diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 74e8480deaff7..0ccd937e72e88 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -227,6 +227,18 @@ This file is divided into 3 sections: ]]> + + new.*ParVector + + + (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 85dbc499fbde5..04128216be073 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -951,7 +951,6 @@ primaryExpression | qualifiedName DOT ASTERISK #star | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor | LEFT_PAREN query RIGHT_PAREN #subqueryExpression - | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument (COMMA argument+=functionArgument)*)? RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? @@ -1176,6 +1175,7 @@ qualifiedNameList functionName : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | identFunc=IDENTIFIER_KW // IDENTIFIER itself is also a valid function name. | qualifiedName | FILTER | LEFT diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 3d536b735db59..191ccc5254404 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -130,10 +130,13 @@ object JavaTypeInference { // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(c) + // add type variables from inheritance hierarchy of the class + val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++ + typeVariables // Note that the fields are ordered by name. val fields = properties.map { property => val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, typeVariables) + val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV) // The existence of `javax.annotation.Nonnull`, means this field is not nullable. 
val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) EncoderField( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef0..980eee9390d09 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -237,6 +237,9 @@ trait SparkDateTimeUtils { def toJavaTimestamp(micros: Long): Timestamp = toJavaTimestampNoRebase(rebaseGregorianToJulianMicros(micros)) + def toJavaTimestamp(timeZoneId: String, micros: Long): Timestamp = + toJavaTimestampNoRebase(rebaseGregorianToJulianMicros(timeZoneId, micros)) + /** * Converts microseconds since the epoch to an instance of `java.sql.Timestamp`. * @@ -273,6 +276,9 @@ trait SparkDateTimeUtils { def fromJavaTimestamp(t: Timestamp): Long = rebaseJulianToGregorianMicros(fromJavaTimestampNoRebase(t)) + def fromJavaTimestamp(timeZoneId: String, t: Timestamp): Long = + rebaseJulianToGregorianMicros(timeZoneId, fromJavaTimestampNoRebase(t)) + /** * Converts an instance of `java.sql.Timestamp` to the number of microseconds since * 1970-01-01T00:00:00.000000Z. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 8a288d0e9f3a3..07b32af5c85e8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -167,8 +167,9 @@ class Iso8601TimestampFormatter( override def parseOptional(s: String): Option[Long] = { try { - val parsed = formatter.parseUnresolved(s, new ParsePosition(0)) - if (parsed != null) { + val parsePosition = new ParsePosition(0) + val parsed = formatter.parseUnresolved(s, parsePosition) + if (parsed != null && s.length == parsePosition.getIndex) { Some(extractMicros(parsed)) } else { None @@ -196,8 +197,9 @@ class Iso8601TimestampFormatter( override def parseWithoutTimeZoneOptional(s: String, allowTimeZone: Boolean): Option[Long] = { try { - val parsed = formatter.parseUnresolved(s, new ParsePosition(0)) - if (parsed != null) { + val parsePosition = new ParsePosition(0) + val parsed = formatter.parseUnresolved(s, parsePosition) + if (parsed != null && s.length == parsePosition.getIndex) { Some(extractMicrosNTZ(s, parsed, allowTimeZone)) } else { None @@ -412,10 +414,14 @@ class LegacyFastTimestampFormatter( override def parseOptional(s: String): Option[Long] = { cal.clear() // Clear the calendar because it can be re-used many times - if (fastDateFormat.parse(s, new ParsePosition(0), cal)) { - Some(extractMicros(cal)) - } else { - None + try { + if (fastDateFormat.parse(s, new ParsePosition(0), cal)) { + Some(extractMicros(cal)) + } else { + None + } + } catch { + case NonFatal(_) => None } } @@ -423,11 +429,11 @@ class LegacyFastTimestampFormatter( val micros = cal.getMicros() cal.set(Calendar.MILLISECOND, 0) val julianMicros = Math.addExact(millisToMicros(cal.getTimeInMillis), micros) - rebaseJulianToGregorianMicros(julianMicros) + rebaseJulianToGregorianMicros(TimeZone.getTimeZone(zoneId), julianMicros) } override def format(timestamp: Long): String = { - val julianMicros = rebaseGregorianToJulianMicros(timestamp) + val julianMicros = rebaseGregorianToJulianMicros(TimeZone.getTimeZone(zoneId), timestamp) 
cal.setTimeInMillis(Math.floorDiv(julianMicros, MICROS_PER_SECOND) * MILLIS_PER_SECOND) cal.setMicros(Math.floorMod(julianMicros, MICROS_PER_SECOND)) fastDateFormat.format(cal) @@ -437,7 +443,7 @@ class LegacyFastTimestampFormatter( if (ts.getNanos == 0) { fastDateFormat.format(ts) } else { - format(fromJavaTimestamp(ts)) + format(fromJavaTimestamp(zoneId.getId, ts)) } } @@ -461,7 +467,7 @@ class LegacySimpleTimestampFormatter( } override def parse(s: String): Long = { - fromJavaTimestamp(new Timestamp(sdf.parse(s).getTime)) + fromJavaTimestamp(zoneId.getId, new Timestamp(sdf.parse(s).getTime)) } override def parseOptional(s: String): Option[Long] = { @@ -469,12 +475,12 @@ class LegacySimpleTimestampFormatter( if (date == null) { None } else { - Some(fromJavaTimestamp(new Timestamp(date.getTime))) + Some(fromJavaTimestamp(zoneId.getId, new Timestamp(date.getTime))) } } override def format(us: Long): String = { - sdf.format(toJavaTimestamp(us)) + sdf.format(toJavaTimestamp(zoneId.getId, us)) } override def format(ts: Timestamp): String = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index f63fc8c4785bc..d9bfb26906420 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -283,7 +283,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def nestedTypeMissingElementTypeError( dataType: String, ctx: PrimitiveDataTypeContext): Throwable = { - dataType match { + dataType.toUpperCase(Locale.ROOT) match { case "ARRAY" => new ParseException( errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index d746e9037ec48..5ec72b83837ee 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.internal import java.util.TimeZone -import java.util.concurrent.atomic.AtomicReference import scala.util.Try @@ -48,25 +47,14 @@ private[sql] trait SqlApiConf { private[sql] object SqlApiConf { // Shared keys. - val ANSI_ENABLED_KEY: String = "spark.sql.ansi.enabled" - val LEGACY_TIME_PARSER_POLICY_KEY: String = "spark.sql.legacy.timeParserPolicy" - val CASE_SENSITIVE_KEY: String = "spark.sql.caseSensitive" - val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone" - val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold" + val ANSI_ENABLED_KEY: String = SqlApiConfHelper.ANSI_ENABLED_KEY + val LEGACY_TIME_PARSER_POLICY_KEY: String = SqlApiConfHelper.LEGACY_TIME_PARSER_POLICY_KEY + val CASE_SENSITIVE_KEY: String = SqlApiConfHelper.CASE_SENSITIVE_KEY + val SESSION_LOCAL_TIMEZONE_KEY: String = SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY + val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = + SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY - /** - * Defines a getter that returns the [[SqlApiConf]] within scope. - */ - private val confGetter = new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf) - - /** - * Sets the active config getter. 
- */ - private[sql] def setConfGetter(getter: () => SqlApiConf): Unit = { - confGetter.set(getter) - } - - def get: SqlApiConf = confGetter.get()() + def get: SqlApiConf = SqlApiConfHelper.getConfGetter.get()() // Force load SQLConf. This will trigger the installation of a confGetter that points to SQLConf. Try(SparkClassUtils.classForName("org.apache.spark.sql.internal.SQLConf$")) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala new file mode 100644 index 0000000000000..79b6cb9231c51 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import java.util.concurrent.atomic.AtomicReference + +/** + * SqlApiConfHelper is created to avoid a deadlock during a concurrent access to SQLConf and + * SqlApiConf, which is because SQLConf and SqlApiConf tries to load each other upon + * initializations. SqlApiConfHelper is private to sql package and is not supposed to be + * accessed by end users. Variables and methods within SqlApiConfHelper are defined to + * be used by SQLConf and SqlApiConf only. + */ +private[sql] object SqlApiConfHelper { + // Shared keys. + val ANSI_ENABLED_KEY: String = "spark.sql.ansi.enabled" + val LEGACY_TIME_PARSER_POLICY_KEY: String = "spark.sql.legacy.timeParserPolicy" + val CASE_SENSITIVE_KEY: String = "spark.sql.caseSensitive" + val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone" + val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold" + + val confGetter: AtomicReference[() => SqlApiConf] = { + new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf) + } + + def getConfGetter: AtomicReference[() => SqlApiConf] = confGetter + + /** + * Sets the active config getter. 
+ */ + def setConfGetter(getter: () => SqlApiConf): Unit = { + confGetter.set(getter) + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a6824..77e9aa06c830c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -499,7 +499,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def / (that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, - DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) + DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode)) def % (that: Decimal): Decimal = if (that.isZero) null @@ -547,7 +547,11 @@ object Decimal { val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) + // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results + // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer + // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down + // the last extra digit. + private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN) private[sql] val ZERO = Decimal(0) private[sql] val ONE = Decimal(1) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8edc7cf370b7d..8fd7f47b34624 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import scala.collection.{mutable, Map} import scala.util.Try import scala.util.control.NonFatal @@ -476,8 +478,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be * thrown. 
*/ - private[sql] def merge(that: StructType): StructType = - StructType.merge(this, that).asInstanceOf[StructType] + private[sql] def merge(that: StructType, caseSensitive: Boolean = true): StructType = + StructType.merge(this, that, caseSensitive).asInstanceOf[StructType] override private[spark] def asNullable: StructType = { val newFields = fields.map { @@ -561,16 +563,20 @@ object StructType extends AbstractDataType { StructType(newFields) }) - private[sql] def merge(left: DataType, right: DataType): DataType = + private[sql] def merge(left: DataType, right: DataType, caseSensitive: Boolean = true): DataType = mergeInternal(left, right, (s1: StructType, s2: StructType) => { val leftFields = s1.fields val rightFields = s2.fields val newFields = mutable.ArrayBuffer.empty[StructField] - val rightMapped = fieldsMap(rightFields) + def normalize(name: String): String = { + if (caseSensitive) name else name.toLowerCase(Locale.ROOT) + } + + val rightMapped = fieldsMap(rightFields, caseSensitive) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightMapped.get(leftName) + rightMapped.get(normalize(leftName)) .map { case rightField @ StructField(rightName, rightType, rightNullable, _) => try { leftField.copy( @@ -588,9 +594,9 @@ object StructType extends AbstractDataType { .foreach(newFields += _) } - val leftMapped = fieldsMap(leftFields) + val leftMapped = fieldsMap(leftFields, caseSensitive) rightFields - .filterNot(f => leftMapped.get(f.name).nonEmpty) + .filterNot(f => leftMapped.contains(normalize(f.name))) .foreach { f => newFields += f } @@ -643,11 +649,15 @@ object StructType extends AbstractDataType { throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) } - private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { + private[sql] def fieldsMap( + fields: Array[StructField], + caseSensitive: Boolean = true): Map[String, StructField] = { // Mimics the optimization of breakOut, not present in Scala 2.13, while working in 2.12 val map = mutable.Map[String, StructField]() map.sizeHint(fields.length) - fields.foreach(s => map.put(s.name, s)) + fields.foreach { s => + if (caseSensitive) map.put(s.name, s) else map.put(s.name.toLowerCase(Locale.ROOT), s) + } map } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index f9cd2dc677adf..0564a6be7432a 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index f6686d2e4d3b6..786821514822e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -38,7 +38,7 @@ @Evolving public abstract class DelegatingCatalogExtension implements CatalogExtension { - private CatalogPlugin delegate; + protected CatalogPlugin delegate; @Override public final void setDelegateCatalog(CatalogPlugin delegate) { @@ -51,7 +51,7 @@ public String name() { } @Override - public final void initialize(String name, CaseInsensitiveStringMap options) {} + public void initialize(String name, CaseInsensitiveStringMap options) {} @Override public Set capabilities() { diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java index 4337a7c615208..3094b0cf1bbda 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java @@ -80,7 +80,8 @@ StagedTable stageCreate( * @param columns the column of the new table * @param partitions transforms to use for partitioning data in the table * @param properties a string map of table properties - * @return metadata for the new table + * @return metadata for the new table. This can be null if the catalog does not support atomic + * creation for this table. Spark will call {@link #loadTable(Identifier)} later. * @throws TableAlreadyExistsException If a table or view already exists for the identifier * @throws UnsupportedOperationException If a requested partition transform is not supported * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) @@ -128,7 +129,8 @@ StagedTable stageReplace( * @param columns the columns of the new table * @param partitions transforms to use for partitioning data in the table * @param properties a string map of table properties - * @return metadata for the new table + * @return metadata for the new table. This can be null if the catalog does not support atomic + * creation for this table. Spark will call {@link #loadTable(Identifier)} later. * @throws UnsupportedOperationException If a requested partition transform is not supported * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) * @throws NoSuchTableException If the table does not exist @@ -176,7 +178,8 @@ StagedTable stageCreateOrReplace( * @param columns the columns of the new table * @param partitions transforms to use for partitioning data in the table * @param properties a string map of table properties - * @return metadata for the new table + * @return metadata for the new table. This can be null if the catalog does not support atomic + * creation for this table. Spark will call {@link #loadTable(Identifier)} later. * @throws UnsupportedOperationException If a requested partition transform is not supported * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsMetadataColumns.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsMetadataColumns.java index 894184dbcc82d..e42424268b44d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsMetadataColumns.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsMetadataColumns.java @@ -58,8 +58,8 @@ public interface SupportsMetadataColumns extends Table { * Determines how this data source handles name conflicts between metadata and data columns. *
<p>
    * If true, spark will automatically rename the metadata column to resolve the conflict. End users - * can reliably select metadata columns (renamed or not) with {@link Dataset.metadataColumn}, and - * internal code can use {@link MetadataAttributeWithLogicalName} to extract the logical name from + * can reliably select metadata columns (renamed or not) with {@code Dataset.metadataColumn}, and + * internal code can use {@code MetadataAttributeWithLogicalName} to extract the logical name from * a metadata attribute. *
<p>
    * If false, the data column will hide the metadata column. It is recommended that Table diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index d99e7e14b0117..d1951a7f7fbf3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -46,13 +46,14 @@ public interface TableCatalog extends CatalogPlugin { /** * A reserved property to specify the location of the table. The files of the table - * should be under this location. + * should be under this location. The location is a Hadoop Path string. */ String PROP_LOCATION = "location"; /** * A reserved property to indicate that the table location is managed, not user-specified. - * If this property is "true", SHOW CREATE TABLE will not generate the LOCATION clause. + * If this property is "true", it means it's a managed table even if it has a location. As an + * example, SHOW CREATE TABLE will not generate the LOCATION clause. */ String PROP_IS_MANAGED_LOCATION = "is_managed_location"; @@ -109,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin { */ Table loadTable(Identifier ident) throws NoSuchTableException; + /** + * Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data + * into this table later. + *
<p>
    + * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @param writePrivileges + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist or is a view + * + * @since 3.5.3 + */ + default Table loadTable( + Identifier ident, + Set writePrivileges) throws NoSuchTableException { + return loadTable(ident); + } + /** * Load table metadata of a specific version by {@link Identifier identifier} from the catalog. *
<p>
    @@ -187,7 +208,9 @@ Table createTable( * @param columns the columns of the new table. * @param partitions transforms to use for partitioning data in the table * @param properties a string map of table properties - * @return metadata for the new table + * @return metadata for the new table. This can be null if getting the metadata for the new table + * is expensive. Spark will call {@link #loadTable(Identifier)} if needed (e.g. CTAS). + * * @throws TableAlreadyExistsException If a table or view already exists for the identifier * @throws UnsupportedOperationException If a requested partition transform is not supported * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) @@ -221,7 +244,9 @@ default boolean useNullableQuerySchema() { * * @param ident a table identifier * @param changes changes to apply to the table - * @return updated metadata for the table + * @return updated metadata for the table. This can be null if getting the metadata for the + * updated table is expensive. Spark always discard the returned table here. + * * @throws NoSuchTableException If the table doesn't exist or is a view * @throws IllegalArgumentException If any change is rejected by the implementation. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java new file mode 100644 index 0000000000000..ca2d4ba9e7b4e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +/** + * The table write privileges that will be provided when loading a table. + * + * @since 3.5.3 + */ +public enum TableWritePrivilege { + /** + * The privilege for adding rows to the table. + */ + INSERT, + + /** + * The privilege for changing existing rows in th table. + */ + UPDATE, + + /** + * The privilege for deleting rows from the table. + */ + DELETE +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 9ca0fe4787f10..e170951bfa284 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -48,6 +48,32 @@ */ public class V2ExpressionSQLBuilder { + /** + * Escape the special chars for like pattern. + * + * Note: This method adopts the escape representation within Spark and is not bound to any JDBC + * dialect. 
JDBC dialect should overwrite this API if the underlying database have more special + * chars other than _ and %. + */ + protected String escapeSpecialCharsForLikePattern(String str) { + StringBuilder builder = new StringBuilder(); + + for (char c : str.toCharArray()) { + switch (c) { + case '_': + builder.append("\\_"); + break; + case '%': + builder.append("\\%"); + break; + default: + builder.append(c); + } + } + + return builder.toString(); + } + public String build(Expression expr) { if (expr instanceof Literal) { return visitLiteral((Literal) expr); @@ -247,21 +273,21 @@ protected String visitStartsWith(String l, String r) { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". String value = r.substring(1, r.length() - 1); - return l + " LIKE '" + value + "%'"; + return l + " LIKE '" + escapeSpecialCharsForLikePattern(value) + "%' ESCAPE '\\'"; } protected String visitEndsWith(String l, String r) { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". String value = r.substring(1, r.length() - 1); - return l + " LIKE '%" + value + "'"; + return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "' ESCAPE '\\'"; } protected String visitContains(String l, String r) { // Remove quotes at the beginning and end. // e.g. converts "'str'" to "str". String value = r.substring(1, r.length() - 1); - return l + " LIKE '%" + value + "%'"; + return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "%' ESCAPE '\\'"; } private String inputToSQL(Expression input) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index bd7c3d7c0fd49..e58f36641d298 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -47,26 +47,43 @@ public int numElements() { return length; } + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + @Override public ArrayData copy() { DataType dt = data.dataType(); if (dt instanceof BooleanType) { - return UnsafeArrayData.fromPrimitiveArray(toBooleanArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); } else if (dt instanceof ByteType) { - return UnsafeArrayData.fromPrimitiveArray(toByteArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); } else if (dt instanceof ShortType) { - return UnsafeArrayData.fromPrimitiveArray(toShortArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); } else if (dt instanceof IntegerType || dt instanceof DateType || dt instanceof YearMonthIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toIntArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); } else if (dt instanceof LongType || dt instanceof TimestampType || dt instanceof DayTimeIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toLongArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); } else if (dt instanceof FloatType) { - return UnsafeArrayData.fromPrimitiveArray(toFloatArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); } else if (dt instanceof DoubleType) { - return UnsafeArrayData.fromPrimitiveArray(toDoubleArray()); + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); } else { return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. } diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 3e545f745baee..c18679330f3a4 100644 --- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable +import scala.collection.{mutable, GenTraversableOnce} import scala.collection.mutable.ArrayBuffer object ExpressionSet { @@ -108,12 +108,31 @@ class ExpressionSet protected( newSet } + /** + * SPARK-47897: In Scala 2.12, the `SetLike.++` method iteratively calls `+` method. + * `ExpressionSet.+` is expensive, so we override `++`. + */ + override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = { + val newSet = clone() + elems.foreach(newSet.add) + newSet + } + override def -(elem: Expression): ExpressionSet = { val newSet = clone() newSet.remove(elem) newSet } + /** + * SPARK-47897: We need to override `--` like `++`. 
+ */ + override def --(elems: GenTraversableOnce[Expression]): ExpressionSet = { + val newSet = clone() + elems.foreach(newSet.remove) + newSet + } + def map(f: Expression => Expression): ExpressionSet = { val newSet = new ExpressionSet() this.iterator.foreach(elem => newSet.add(f(elem))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 16a7d7ff06526..0b88d5a4130e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -390,7 +390,9 @@ object DeserializerBuildHelper { CreateExternalRow(convertedFields, enc.schema)) case JavaBeanEncoder(tag, fields) => - val setters = fields.map { f => + val setters = fields + .filter(_.writeMethod.isDefined) + .map { f => val newTypePath = walkedTypePath.recordField( f.enc.clsTag.runtimeClass.getName, f.name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala index 429ce805bf2c4..75573cb72d839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -25,7 +25,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying * the underlying data. */ -case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) extends InternalRow { +case class ProjectingInternalRow(schema: StructType, + colOrdinals: IndexedSeq[Int]) extends InternalRow { assert(schema.size == colOrdinals.size) private var row: InternalRow = _ @@ -115,3 +116,9 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte row.get(colOrdinals(ordinal), dataType) } } + +object ProjectingInternalRow { + def apply(schema: StructType, colOrdinals: Seq[Int]): ProjectingInternalRow = { + new ProjectingInternalRow(schema, colOrdinals.toIndexedSeq) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 27090ff6fa5d6..cd087514f4be3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -450,10 +450,15 @@ object SerializerBuildHelper { private def validateAndSerializeElement( enc: AgnosticEncoder[_], nullable: Boolean): Expression => Expression = { input => + val expected = enc match { + case OptionEncoder(_) => lenientExternalDataTypeFor(enc) + case _ => enc.dataType + } + expressionWithNullSafety( createSerializer( enc, - ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))), + ValidateExternalType(input, expected, lenientExternalDataTypeFor(enc))), nullable, WalkedTypePath()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6c5d19f58ac25..463bd3c3a8a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import scala.util.{Failure, Random, Success, Try} import org.apache.spark.sql.AnalysisException @@ -78,7 +79,7 @@ object SimpleAnalyzer extends Analyzer( override def resolver: Resolver = caseSensitiveResolution } -object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog { +object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog with SupportsNamespaces { private def fail() = throw new UnsupportedOperationException override def listTables(namespace: Array[String]): Array[Identifier] = fail() override def loadTable(ident: Identifier): Table = { @@ -92,10 +93,23 @@ object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog { override def alterTable(ident: Identifier, changes: TableChange*): Table = fail() override def dropTable(ident: Identifier): Boolean = fail() override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = fail() - override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = fail() + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} override def name(): String = CatalogManager.SESSION_CATALOG_NAME override def listFunctions(namespace: Array[String]): Array[Identifier] = fail() override def loadFunction(ident: Identifier): UnboundFunction = fail() + override def listNamespaces(): Array[Array[String]] = fail() + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = fail() + override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = { + if (namespace.length == 1) { + mutable.HashMap[String, String]().asJava + } else { + throw new NoSuchNamespaceException(namespace) + } + } + override def createNamespace( + namespace: Array[String], metadata: util.Map[String, String]): Unit = fail() + override def alterNamespace(namespace: Array[String], changes: NamespaceChange*): Unit = fail() + override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = fail() } /** @@ -255,7 +269,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor TypeCoercion.typeCoercionRules } - override def batches: Seq[Batch] = Seq( + private def earlyBatches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. 
// However, when manipulating deeply nested schema, `UpdateFields` expression tree could be @@ -275,7 +289,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs), + KeepLegacyOutputs) + ) + + override def batches: Seq[Batch] = earlyBatches ++ Seq( Batch("Resolution", fixedPoint, new ResolveCatalogs(catalogManager) :: ResolveInsertInto :: @@ -307,6 +324,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: ResolveOutputRelation :: + new ResolveDataFrameDropColumns(catalogManager) :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -318,12 +336,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveTimeZone :: ResolveRandomSeed :: ResolveBinaryArithmetic :: - ResolveIdentifierClause :: + new ResolveIdentifierClause(earlyBatches) :: ResolveUnion :: ResolveRowLevelCommandAssignments :: - RewriteDeleteFromTable :: - RewriteUpdateTable :: - RewriteMergeIntoTable :: typeCoercionRules ++ Seq( ResolveWithCTE, @@ -337,17 +352,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor new ResolveHints.RemoveAllHints), Batch("Nondeterministic", Once, PullOutNondeterministic), - Batch("UDF", Once, + Batch("ScalaUDF Null Handling", fixedPoint, + // `HandleNullInputsForUDF` may wrap the `ScalaUDF` with `If` expression to return null for + // null inputs, so the result can be null even if `ScalaUDF#nullable` is false. We need to + // run `UpdateAttributeNullability` to update nullability of the UDF output attribute in + // downstream operators. After updating attribute nullability, `ScalaUDF`s in downstream + // operators may need null handling as well, so we should run these two rules repeatedly. HandleNullInputsForUDF, - ResolveEncodersInUDF), - Batch("UpdateNullability", Once, UpdateAttributeNullability), + Batch("UDF", Once, + ResolveEncodersInUDF), + // The rewrite rules might move resolved query plan into subquery. Once the resolved plan + // contains ScalaUDF, their encoders won't be resolved if `ResolveEncodersInUDF` is not + // applied before the rewrite rules. So we need to apply `ResolveEncodersInUDF` before the + // rewrite rules. + Batch("DML rewrite", fixedPoint, + RewriteDeleteFromTable, + RewriteUpdateTable, + RewriteMergeIntoTable), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases), Batch("HandleSpecialCommand", Once, - HandleSpecialCommand) + HandleSpecialCommand), + Batch("Remove watermark for batch query", Once, + EliminateEventTimeWatermark) ) /** @@ -1216,7 +1246,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor options: CaseInsensitiveStringMap, isStreaming: Boolean): Option[LogicalPlan] = { table.map { - case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) => + // To utilize this code path to execute V1 commands, e.g. INSERT, + // either it must be session catalog, or tracksPartitionsInCatalog + // must be false so it does not require use catalog to manage partitions. + // Obviously we cannot execute V1Table by V1 code path if the table + // is not from session catalog and the table still requires its catalog + // to manage partitions. 
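Restating the eligibility rule from the comment above as a tiny standalone predicate; this helper is hypothetical and exists only to make the condition in the guard below explicit.

    // Hypothetical predicate: a V1 table may take the V1 command path when it
    // belongs to the session catalog, or when it does not rely on its catalog
    // to track partitions.
    def canUseV1CommandPath(isSessionCatalog: Boolean, tracksPartitionsInCatalog: Boolean): Boolean =
      isSessionCatalog || !tracksPartitionsInCatalog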
+ case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) + || !v1Table.catalogTable.tracksPartitionsInCatalog => if (isStreaming) { if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( @@ -1259,16 +1296,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor expandIdentifier(u.multipartIdentifier) match { case CatalogAndIdentifier(catalog, ident) => val key = ((catalog.name +: ident.namespace :+ ident.name).toSeq, timeTravelSpec) - AnalysisContext.get.relationCache.get(key).map(_.transform { - case multi: MultiInstanceRelation => - val newRelation = multi.newInstance() - newRelation.copyTagsFrom(multi) - newRelation - }).orElse { - val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec) - val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) + AnalysisContext.get.relationCache.get(key).map { cache => + val cachedRelation = cache.transform { + case multi: MultiInstanceRelation => + val newRelation = multi.newInstance() + newRelation.copyTagsFrom(multi) + newRelation + } + u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => + val cachedConnectRelation = cachedRelation.clone() + cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + cachedConnectRelation + }.getOrElse(cachedRelation) + }.orElse { + val writePrivilegesString = + Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) + val table = CatalogV2Util.loadTable( + catalog, ident, timeTravelSpec, writePrivilegesString) + val loaded = createRelation( + catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) - loaded + u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => + loaded.map { loadedRelation => + val loadedConnectRelation = loadedRelation.clone() + loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + loadedConnectRelation + } + }.getOrElse(loaded) } case _ => None } @@ -1969,7 +2023,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError( index, ordinalExpr) } else { - ordinalExpr + trimAliases(ordinalExpr) match { + // HACK ALERT: If the ordinal expression is also an integer literal, don't use it + // but still keep the ordinal literal. The reason is we may repeatedly + // analyze the plan. Using a different integer literal may lead to + // a repeat GROUP BY ordinal resolution which is wrong. GROUP BY + // constant is meaningless so whatever value does not matter here. + // TODO: (SPARK-45932) GROUP BY ordinal should pull out grouping expressions to + // a Project, then the resolved ordinal expression is always + // `AttributeReference`. + case Literal(_: Int, IntegerType) => + Literal(index) + case _ => ordinalExpr + } } } else { throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size) @@ -2727,28 +2793,36 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + // We must wait until all expressions except for generator functions are resolved before + // rewriting generator functions in Project/Aggregate. This is necessary to make this rule + // stable for different execution orders of analyzer rules. See also SPARK-47241. 
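A minimal, self-contained sketch of the situation this comment (and the check defined next) is about; the object, table, and column names are invented for illustration. A projection that mixes a generator with ordinary expressions still analyzes cleanly because the generator rewrite is deferred until the non-generator entries are resolved.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, upper}

object GeneratorRewriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("generator-rewrite").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", Seq(1, 2, 3))).toDF("name", "xs")

    // The generator (explode) sits next to a regular expression (upper). Per the comment
    // above, the rewrite into a Generate node only happens once the other project-list
    // entries are resolved, keeping the rule stable across rule execution orders.
    df.select(upper(col("name")), explode(col("xs"))).show()

    spark.stop()
  }
}
```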
+ private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = { + namedExprs.forall { ne => + ne.resolved || { + trimNonTopLevelAliases(ne) match { + case AliasedGenerator(_, _, _) => true + case _ => false + } + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(GENERATOR), ruleId) { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) - case Project(projectList, _) if projectList.count(hasGenerator) > 1 => - val generators = projectList.filter(hasGenerator).map(trimAlias) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") + throw QueryCompilationErrors.moreThanOneGeneratorError(generators) - case agg @ Aggregate(groupList, aggList, child) if aggList.forall { - case AliasedGenerator(_, _, _) => true - case other => other.resolved - } && aggList.exists(hasGenerator) => + case Aggregate(groupList, aggList, child) if canRewriteGenerator(aggList) && + aggList.exists(hasGenerator) => // If generator in the aggregate list was visited, set the boolean flag true. var generatorVisited = false @@ -2793,16 +2867,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // first for replacing `Project` with `Aggregate`. p - case p @ Project(projectList, child) => + case p @ Project(projectList, child) if canRewriteGenerator(projectList) && + projectList.exists(hasGenerator) => val (resolvedGenerator, newProjectList) = projectList .map(trimNonTopLevelAliases) .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) => e match { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. - assert(res._1.isEmpty, "More than one generator found in SELECT.") - + // If there are more than one generator, we only rewrite the first one and wait for + // the next analyzer iteration to rewrite the next one. 
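Assuming the iterative rewrite behaves as the comment above describes (and as the removal of the SELECT-specific "more than one generator" error elsewhere in this patch suggests), a projection with two generators would be unfolded one generator per analyzer iteration. A hedged sketch, with invented object and column names:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode}

object MultipleGeneratorsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("multi-generators").getOrCreate()
    import spark.implicits._

    val df = Seq((Seq(1, 2), Seq("x", "y"))).toDF("nums", "letters")

    // If the rule works as described, the first explode is rewritten in one iteration
    // and the second in the next, showing up as nested Generate operators in the plan.
    df.select(explode(col("nums")), explode(col("letters"))).explain(true)

    spark.stop()
  }
}
```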
+ case AliasedGenerator(generator, names, outer) if res._1.isEmpty && + generator.childrenResolved => val g = Generate( generator, unrequiredChildIndex = Nil, @@ -2810,7 +2884,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) - (Some(g), res._2 ++ g.nullableOutput) case other => (res._1, res._2 :+ other) @@ -2830,6 +2903,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction => u + case p: Project => p + + case a: Aggregate => a + case p if p.expressions.exists(hasGenerator) => throw QueryCompilationErrors.generatorOutsideSelectError(p) } @@ -3796,9 +3873,9 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) - case CollectMetrics(name, metrics, child) => + case CollectMetrics(name, metrics, child, dataframeId) => val cleanedMetrics = metrics.map(trimNonTopLevelAliases) - CollectMetrics(name, cleanedMetrics, child) + CollectMetrics(name, cleanedMetrics, child, dataframeId) case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) => val cleanedIds = ids.map(_.map(trimNonTopLevelAliases)) @@ -3831,7 +3908,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { object EliminateEventTimeWatermark extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( _.containsPattern(EVENT_TIME_WATERMARK)) { - case EventTimeWatermark(_, _, child) if !child.isStreaming => child + case EventTimeWatermark(_, _, child) if child.resolved && !child.isStreaming => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 954f5f19cd3ec..7321f5becdc48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -263,7 +263,7 @@ object CTESubstitution extends Rule[LogicalPlan] { d.child } else { // Add a `SubqueryAlias` for hint-resolving rules to match relation names. 
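The comment above is easier to see with a concrete query: a join hint that names a CTE can only be matched if the CTE reference is wrapped in a `SubqueryAlias` carrying that name. A rough illustration, with invented view and column names:

```scala
import org.apache.spark.sql.SparkSession

object CteHintExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("cte-hint").getOrCreate()

    spark.range(10).createOrReplaceTempView("facts")

    // The BROADCAST hint refers to the CTE by name. Hint resolution matches that name
    // against the SubqueryAlias that CTESubstitution keeps on top of the CTERelationRef.
    spark.sql(
      """WITH small AS (SELECT id FROM facts WHERE id < 3)
        |SELECT /*+ BROADCAST(small) */ f.id
        |FROM facts f JOIN small s ON f.id = s.id
        |""".stripMargin).explain()

    spark.stop()
  }
}
```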
- SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output)) + SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming)) } }.getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 43546bcaa421a..bb399e41d7d01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -64,12 +64,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB messageParameters = messageParameters) } - protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { - exprs.flatMap(_.collect { - case e: Generator => e - }).length > 1 - } - protected def hasMapType(dt: DataType): Boolean = { dt.existsRecursively(_.isInstanceOf[MapType]) } @@ -147,17 +141,56 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } + private def checkUnreferencedCTERelations( + cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + visited: mutable.Map[Long, Boolean], + danglingCTERelations: mutable.ArrayBuffer[CTERelationDef], + cteId: Long): Unit = { + if (visited(cteId)) { + return + } + val (cteDef, _, refMap) = cteMap(cteId) + refMap.foreach { case (id, _) => + checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id) + } + danglingCTERelations.append(cteDef) + visited(cteId) = true + } + + /** + * Checks whether the operator allows non-deterministic expressions. + */ + private def operatorAllowsNonDeterministicExpressions(plan: LogicalPlan): Boolean = { + plan match { + case p: SupportsNonDeterministicExpression => + p.allowNonDeterministicExpression + case _ => false + } + } + def checkAnalysis(plan: LogicalPlan): Unit = { val inlineCTE = InlineCTE(alwaysInline = true) val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] inlineCTE.buildCTEMap(plan, cteMap) - cteMap.values.foreach { case (relation, refCount, _) => - // If a CTE relation is never used, it will disappear after inline. Here we explicitly check - // analysis for it, to make sure the entire query plan is valid. - if (refCount == 0) checkAnalysis0(relation.child) + val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef] + val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) + // If a CTE relation is never used, it will disappear after inline. Here we explicitly collect + // these dangling CTE relations, and put them back in the main query, to make sure the entire + // query plan is valid. + cteMap.foreach { case (cteId, (_, refCount, _)) => + // If a CTE relation ref count is 0, the other CTE relations that reference it should also be + // collected. This code will also guarantee the leaf relations that do not reference + // any others are collected first. + if (refCount == 0) { + checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId) + } } // Inline all CTEs in the plan to help check query plan structures in subqueries. 
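The dangling-CTE handling above can be exercised directly: a CTE that the main query never references would disappear after inlining, yet it still has to pass analysis so that errors inside it surface. A small sketch, with invented view and column names:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object UnreferencedCteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dangling-cte").getOrCreate()

    spark.range(5).toDF("id").createOrReplaceTempView("src")

    // `unused` is never referenced by the final SELECT, so inlining would drop it.
    // Because unreferenced CTE definitions are put back into the plan before the check,
    // the bad column reference inside it is still reported.
    try {
      spark.sql(
        """WITH unused AS (SELECT no_such_column FROM src)
          |SELECT id FROM src
          |""".stripMargin).collect()
    } catch {
      case e: AnalysisException => println(s"Analysis failed as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}
```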
- checkAnalysis0(inlineCTE(plan)) + var inlinedPlan: LogicalPlan = inlineCTE(plan) + if (danglingCTERelations.nonEmpty) { + inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq) + } + checkAnalysis0(inlinedPlan) plan.setAnalyzed() } @@ -365,6 +398,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB }) operator match { + case RelationTimeTravel(u: UnresolvedRelation, _, _) => + u.tableNotFound(u.multipartIdentifier) + case etw: EventTimeWatermark => etw.eventTime.dataType match { case s: StructType @@ -377,6 +413,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "eventName" -> toSQLId(etw.eventTime.name), "eventType" -> toSQLType(etw.eventTime.dataType))) } + case f: Filter if f.condition.dataType != BooleanType => f.failAnalysis( errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", @@ -484,7 +521,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB groupingExprs.foreach(checkValidGroupingExprs) aggregateExprs.foreach(checkValidAggregateExpression) - case CollectMetrics(name, metrics, _) => + case CollectMetrics(name, metrics, _, _) => if (name == null || name.isEmpty) { operator.failAnalysis( errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME", @@ -683,10 +720,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB )) } - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => - val generators = exprs.filter(expr => expr.exists(_.isInstanceOf[Generator])) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case p @ Project(projectList, _) => projectList.foreach(_.transformDownWithPruning( _.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) { @@ -749,6 +782,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "dataType" -> toSQLType(mapCol.dataType))) case o if o.expressions.exists(!_.deterministic) && + !operatorAllowsNonDeterministicExpressions(o) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] && !o.isInstanceOf[Expand] && @@ -1075,17 +1109,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB * are allowed (e.g. self-joins). */ private def checkCollectedMetrics(plan: LogicalPlan): Unit = { - val metricsMap = mutable.Map.empty[String, LogicalPlan] + val metricsMap = mutable.Map.empty[String, CollectMetrics] def check(plan: LogicalPlan): Unit = plan.foreach { node => node match { - case metrics @ CollectMetrics(name, _, _) => - val simplifiedMetrics = simplifyPlanForCollectedMetrics(metrics.canonicalized) + case metrics @ CollectMetrics(name, _, _, dataframeId) => metricsMap.get(name) match { case Some(other) => - val simplifiedOther = simplifyPlanForCollectedMetrics(other.canonicalized) // Exact duplicates are allowed. They can be the result // of a CTE that is used multiple times or a self join. - if (simplifiedMetrics != simplifiedOther) { + if (dataframeId != other.dataframeId) { failAnalysis( errorClass = "DUPLICATED_METRICS_NAME", messageParameters = Map("metricName" -> name)) @@ -1104,30 +1136,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB check(plan) } - /** - * This method is only used for checking collected metrics. This method tries to - * remove extra project which only re-assign expr ids from the plan so that we can identify exact - * duplicates metric definition. 
- */ - private def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperators { - case p: Project if p.projectList.size == p.child.output.size => - val assignExprIdOnly = p.projectList.zipWithIndex.forall { - case (Alias(attr: AttributeReference, _), index) => - // The input plan of this method is already canonicalized. The attribute id becomes the - // ordinal of this attribute in the child outputs. So an alias-only Project means the - // the id of the aliased attribute is the same as its index in the project list. - attr.exprId.id == index - case _ => false - } - if (assignExprIdOnly) { - p.child - } else { - p - } - } - } - /** * Validates to make sure the outer references appearing inside the subquery * are allowed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 98cbdea72d53b..c48006286be9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} import org.apache.spark.sql.internal.SQLConf -trait ColumnResolutionHelper extends Logging { +trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { def conf: SQLConf @@ -337,7 +337,7 @@ trait ColumnResolutionHelper extends Logging { throws: Boolean = false, allowOuter: Boolean = false): Expression = { resolveExpression( - expr, + tryResolveColumnByPlanId(expr, plan), resolveColumnByName = nameParts => { plan.resolve(nameParts, conf.resolver) }, @@ -358,21 +358,8 @@ trait ColumnResolutionHelper extends Logging { e: Expression, q: LogicalPlan, allowOuter: Boolean = false): Expression = { - val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) { - // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and - // expression are from Spark Connect, and need to be resolved in this way: - // 1, extract the attached plan id from the expression (UnresolvedAttribute only for now); - // 2, top-down traverse the query plan to find the plan node that matches the plan id; - // 3, if can not find the matching node, fail the analysis due to illegal references; - // 4, resolve the expression with the matching node, if any error occurs here, apply the - // old code path; - resolveExpressionByPlanId(e, q) - } else { - e - } - resolveExpression( - newE, + tryResolveColumnByPlanId(e, q), resolveColumnByName = nameParts => { q.resolveChildren(nameParts, conf.resolver) }, @@ -392,39 +379,46 @@ trait ColumnResolutionHelper extends Logging { } } - private def resolveExpressionByPlanId( + // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and + // expression are from Spark Connect, and need to be resolved in this way: + // 1. extract the attached plan id from UnresolvedAttribute; + // 2. top-down traverse the query plan to find the plan node that matches the plan id; + // 3. 
if can not find the matching node, fail the analysis due to illegal references; + // 4. if more than one matching nodes are found, fail due to ambiguous column reference; + // 5. resolve the expression with the matching node, if any error occurs here, return the + // original expression as it is. + private def tryResolveColumnByPlanId( e: Expression, - q: LogicalPlan): Expression = { - if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) { - return e - } - - e match { - case u: UnresolvedAttribute => - resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u) - case _ => - e.mapChildren(c => resolveExpressionByPlanId(c, q)) - } + q: LogicalPlan, + idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match { + case u: UnresolvedAttribute => + resolveUnresolvedAttributeByPlanId( + u, q, idToPlan: mutable.HashMap[Long, LogicalPlan] + ).getOrElse(u) + case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) => + e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan)) + case _ => e } private def resolveUnresolvedAttributeByPlanId( u: UnresolvedAttribute, - q: LogicalPlan): Option[NamedExpression] = { + q: LogicalPlan, + idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = { val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) if (planIdOpt.isEmpty) return None val planId = planIdOpt.get logDebug(s"Extract plan_id $planId from $u") - val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId)) - if (planOpt.isEmpty) { - // For example: - // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) - // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) - // df1.select(df2.a) <- illegal reference df2.a - throw new AnalysisException(s"When resolving $u, " + - s"fail to find subplan with plan_id=$planId in $q") - } - val plan = planOpt.get + val plan = idToPlan.getOrElseUpdate(planId, { + findPlanById(u, planId, q).getOrElse { + // For example: + // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) + // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) + // df1.select(df2.a) <- illegal reference df2.a + throw new AnalysisException(s"When resolving $u, " + + s"fail to find subplan with plan_id=$planId in $q") + } + }) try { plan.resolve(u.nameParts, conf.resolver) @@ -434,4 +428,28 @@ trait ColumnResolutionHelper extends Logging { None } } + + private def findPlanById( + u: UnresolvedAttribute, + id: Long, + plan: LogicalPlan): Option[LogicalPlan] = { + if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) { + Some(plan) + } else if (plan.children.length == 1) { + findPlanById(u, id, plan.children.head) + } else if (plan.children.length > 1) { + val matched = plan.children.flatMap(findPlanById(u, id, _)) + if (matched.length > 1) { + throw new AnalysisException( + errorClass = "AMBIGUOUS_COLUMN_REFERENCE", + messageParameters = Map("name" -> toSQLId(u.nameParts)), + origin = u.origin + ) + } else { + matched.headOption + } + } else { + None + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 09cf61a77955a..f51127f53b382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -83,7 +83,7 @@ object DecimalPrecision extends TypeCoercionRule { val resultType = widerDecimalType(p1, s1, p2, s2) val newE1 = if 
(e1.dataType == resultType) e1 else Cast(e1, resultType) val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType) - b.makeCopy(Array(newE1, newE2)) + b.withNewChildren(Seq(newE1, newE2)) } /** @@ -202,21 +202,21 @@ object DecimalPrecision extends TypeCoercionRule { case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] && literalPickMinimumPrecision => - b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r)) + b.withNewChildren(Seq(Cast(l, DataTypeUtils.fromLiteral(l)), r)) case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && r.dataType.isInstanceOf[IntegralType] && literalPickMinimumPrecision => - b.makeCopy(Array(l, Cast(r, DataTypeUtils.fromLiteral(r)))) + b.withNewChildren(Seq(l, Cast(r, DataTypeUtils.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case (l @ IntegralTypeExpression(), r @ DecimalExpression(_, _)) => - b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r)) case (l @ DecimalExpression(_, _), r @ IntegralTypeExpression()) => - b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType)))) case (l, r @ DecimalExpression(_, _)) if isFloat(l.dataType) => - b.makeCopy(Array(l, Cast(r, DoubleType))) + b.withNewChildren(Seq(l, Cast(r, DoubleType))) case (l @ DecimalExpression(_, _), r) if isFloat(r.dataType) => - b.makeCopy(Array(Cast(l, DoubleType), r)) + b.withNewChildren(Seq(Cast(l, DoubleType), r)) case _ => b } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 558579cdb80ac..aaf718fab941d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -930,7 +930,14 @@ object FunctionRegistry { since: Option[String] = None): (String, (ExpressionInfo, FunctionBuilder)) = { val info = FunctionRegistryBase.expressionInfo[T](name, since) val funcBuilder = (expressions: Seq[Expression]) => { - assert(expressions.forall(_.resolved), "function arguments must be resolved.") + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + if (lambdas.nonEmpty && !builder.supportsLambda) { + throw new AnalysisException( + errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + messageParameters = Map( + "class" -> builder.getClass.getCanonicalName)) + } + assert(others.forall(_.resolved), "function arguments must be resolved.") val rearrangedExpressions = rearrangeExpressions(name, builder, expressions) val expr = builder.build(name, rearrangedExpressions) if (setAlias) expr.setTagValue(FUNC_ALIAS, name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala new file mode 100644 index 0000000000000..0f9b93cc2986d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS +import org.apache.spark.sql.connector.catalog.CatalogManager + +/** + * A rule that rewrites DataFrameDropColumns to Project. + * Note that DataFrameDropColumns allows and ignores non-existing columns. + */ +class ResolveDataFrameDropColumns(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with ColumnResolutionHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(DF_DROP_COLUMNS)) { + case d: DataFrameDropColumns if d.childrenResolved => + // expressions in dropList can be unresolved, e.g. + // df.drop(col("non-existing-column")) + val dropped = d.dropList.map { + case u: UnresolvedAttribute => + resolveExpressionByPlanChildren(u, d) + case e => e + } + val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr))) + if (remaining.size == d.child.output.size) { + d.child + } else { + Project(remaining, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index e0d3e5629ef8b..9031b33a84a08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -20,19 +20,24 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER import org.apache.spark.sql.types.StringType /** * Resolves the identifier expressions and builds the original plans/expressions. 
*/ -object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper { +class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]) + extends Rule[LogicalPlan] with AliasHelper with EvalHelper { + + private val executor = new RuleExecutor[LogicalPlan] { + override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]] + } override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { - case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved => - p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)) + case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved => + executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)) case other => other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 760ea466b8579..73600f5c70649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -17,28 +17,29 @@ package org.apache.spark.sql.catalyst.analysis -import scala.util.control.NonFatal - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} +import org.apache.spark.sql.catalyst.optimizer.EvalInlineTables +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess +import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLExpr, toSQLId} import org.apache.spark.sql.types.{StructField, StructType} /** - * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. + * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[ResolvedInlineTable]]. */ object ResolveInlineTables extends Rule[LogicalPlan] with CastSupport with AliasHelper with EvalHelper { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( - AlwaysProcess.fn, ruleId) { - case table: UnresolvedInlineTable if table.expressionsResolved => - validateInputDimension(table) - validateInputEvaluable(table) - convert(table) + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) { + case table: UnresolvedInlineTable if table.expressionsResolved => + validateInputDimension(table) + validateInputEvaluable(table) + val resolvedTable = findCommonTypesAndCast(table) + earlyEvalIfPossible(resolvedTable) + } } /** @@ -74,7 +75,10 @@ object ResolveInlineTables extends Rule[LogicalPlan] table.rows.foreach { row => row.foreach { e => // Note that nondeterministic expressions are not supported since they are not foldable. 
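A hedged illustration of the foldability requirement mentioned above, together with the relaxed handling of CURRENT_LIKE expressions introduced by this change (evaluation of such values is deferred rather than done during analysis). Behavior is sketched from the rule as written here; the alias names are invented:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object InlineTableExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("inline-tables").getOrCreate()

    // current_timestamp() is a CURRENT_LIKE expression: per this patch it is kept in the
    // ResolvedInlineTable and only turned into a LocalRelation after it has been
    // replaced by a literal in a later phase.
    spark.sql("SELECT * FROM VALUES (1, current_timestamp()) AS t(id, ts)").show(truncate = false)

    // rand() is nondeterministic and not foldable, so the inline table should be rejected.
    try {
      spark.sql("SELECT * FROM VALUES (1, rand()) AS t(id, r)").show()
    } catch {
      case e: AnalysisException => println(s"Rejected as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}
```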
- if (!e.resolved || !trimAliases(prepareForEval(e)).foldable) { + // Only exception are CURRENT_LIKE expressions, which are replaced by a literal + // In later stages. + if ((!e.resolved && !e.containsPattern(CURRENT_LIKE)) + || !trimAliases(prepareForEval(e)).foldable) { e.failAnalysis( errorClass = "INVALID_INLINE_TABLE.CANNOT_EVALUATE_EXPRESSION_IN_INLINE_TABLE", messageParameters = Map("expr" -> toSQLExpr(e))) @@ -84,14 +88,12 @@ object ResolveInlineTables extends Rule[LogicalPlan] } /** - * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]] - * into a [[LocalRelation]]. - * * This function attempts to coerce inputs into consistent types. * * This is package visible for unit testing. */ - private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = { + private[analysis] def findCommonTypesAndCast(table: UnresolvedInlineTable): + ResolvedInlineTable = { // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) @@ -105,26 +107,30 @@ object ResolveInlineTables extends Rule[LogicalPlan] val attributes = DataTypeUtils.toAttributes(StructType(fields)) assert(fields.size == table.names.size) - val newRows: Seq[InternalRow] = table.rows.map { row => - InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => - val targetType = fields(ci).dataType - try { + val castedRows: Seq[Seq[Expression]] = table.rows.map { row => + row.zipWithIndex.map { + case (e, ci) => + val targetType = fields(ci).dataType val castedExpr = if (DataTypeUtils.sameType(e.dataType, targetType)) { e } else { cast(e, targetType) } - prepareForEval(castedExpr).eval() - } catch { - case NonFatal(ex) => - table.failAnalysis( - errorClass = "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", - messageParameters = Map("sqlExpr" -> toSQLExpr(e)), - cause = ex) - } - }) + castedExpr + } } - LocalRelation(attributes, newRows) + ResolvedInlineTable(castedRows, attributes) + } + + /** + * This function attempts to early evaluate rows in inline table. + * If evaluation doesn't rely on non-deterministic expressions (e.g. current_like) + * expressions will be evaluated and inlined as [[LocalRelation]] + * This is package visible for unit testing. 
+ */ + private[analysis] def earlyEvalIfPossible(table: ResolvedInlineTable): LogicalPlan = { + val earlyEvalPossible = table.rows.flatten.forall(!_.containsPattern(CURRENT_LIKE)) + if (earlyEvalPossible) EvalInlineTables(table) else table } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 09ae87b071fdd..a03d5438ff6aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} @@ -134,7 +134,19 @@ object ResolveReferencesInAggregate extends SQLConfHelper groupExprs } else { // This is a valid GROUP BY ALL aggregate. - expandedGroupExprs.get + expandedGroupExprs.get.zipWithIndex.map { case (expr, index) => + trimAliases(expr) match { + // HACK ALERT: If the expanded grouping expression is an integer literal, don't use it + // but use an integer literal of the index. The reason is we may repeatedly + // analyze the plan, and the original integer literal may cause failures + // with a later GROUP BY ordinal resolution. GROUP BY constant is + // meaningless so whatever value does not matter here. + case IntegerLiteral(_) => + // GROUP BY ordinal uses 1-based index. 
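The GROUP BY ALL workaround above has a direct SQL counterpart: if one of the expanded grouping expressions is itself an integer literal, a re-analysis pass could misread it as a GROUP BY ordinal, so it is replaced by the 1-based ordinal of its own position. A rough way to exercise that path, with an invented view name:

```scala
import org.apache.spark.sql.SparkSession

object GroupByAllExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("group-by-all").getOrCreate()

    spark.range(10).toDF("id").createOrReplaceTempView("t")

    // The constant column 1 ends up among the expanded grouping expressions of
    // GROUP BY ALL. Keeping it as a plain literal could be re-interpreted as the
    // ordinal "GROUP BY 1" on a second analysis pass, hence the substitution above.
    spark.sql("SELECT 1 AS bucket, id % 2 AS parity, count(*) FROM t GROUP BY ALL").show()

    spark.stop()
  }
}
```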
+ Literal(index + 1) + case _ => expr + } + } } } else { groupExprs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala index 1ee218f9369c5..d1b43283e74b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala @@ -87,84 +87,84 @@ object TimeWindowing extends Rule[LogicalPlan] { val window = windowExpressions.head if (StructType.acceptsType(window.timeColumn.dataType)) { - return p.transformExpressions { + p.transformExpressions { case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn)) } - } - - val metadata = window.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } - - val newMetadata = new MetadataBuilder() - .withMetadata(metadata) - .putBoolean(TimeWindow.marker, true) - .build() + } else { + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } - def getWindow(i: Int, dataType: DataType): Expression = { - val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) - val remainder = (timestamp - window.startTime) % window.slideDuration - val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0), - remainder + window.slideDuration)), Some(remainder)) - val windowStart = lastStart - i * window.slideDuration - val windowEnd = windowStart + window.windowDuration + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(TimeWindow.marker, true) + .build() - // We make sure value fields are nullable since the dataType of TimeWindow defines them - // as nullable. - CreateNamedStruct( - Literal(WINDOW_START) :: - PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: - Literal(WINDOW_END) :: - PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: - Nil) - } + def getWindow(i: Int, dataType: DataType): Expression = { + val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) + val remainder = (timestamp - window.startTime) % window.slideDuration + val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0), + remainder + window.slideDuration)), Some(remainder)) + val windowStart = lastStart - i * window.slideDuration + val windowEnd = windowStart + window.windowDuration + + // We make sure value fields are nullable since the dataType of TimeWindow defines them + // as nullable. 
+ CreateNamedStruct( + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: + Nil) + } - val windowAttr = AttributeReference( - WINDOW_COL_NAME, window.dataType, metadata = newMetadata)() + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = newMetadata)() - if (window.windowDuration == window.slideDuration) { - val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( - exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata)) + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( + exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata)) - val replacedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } - // For backwards compatibility we add a filter to filter out nulls - val filterExpr = IsNotNull(window.timeColumn) + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) - replacedPlan.withNewChildren( - Project(windowStruct +: child.output, - Filter(filterExpr, child)) :: Nil) - } else { - val overlappingWindows = - math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = - Seq.tabulate(overlappingWindows)(i => - getWindow(i, window.timeColumn.dataType)) - - val projections = windows.map(_ +: child.output) - - // When the condition windowDuration % slideDuration = 0 is fulfilled, - // the estimation of the number of windows becomes exact one, - // which means all produced windows are valid. - val filterExpr = - if (window.windowDuration % window.slideDuration == 0) { - IsNotNull(window.timeColumn) + replacedPlan.withNewChildren( + Project(windowStruct +: child.output, + Filter(filterExpr, child)) :: Nil) } else { - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => + getWindow(i, window.timeColumn.dataType)) + + val projections = windows.map(_ +: child.output) + + // When the condition windowDuration % slideDuration = 0 is fulfilled, + // the estimation of the number of windows becomes exact one, + // which means all produced windows are valid. 
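A concrete instance of the window-count estimate described above, using the public DataFrame API (data values invented for the sketch): with a 10-minute window sliding every 3 minutes, each row is expanded into ceil(10 / 3) = 4 candidate windows, and since 10 % 3 != 0 the extra range filter is what discards the invalid ones.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, window}

object SlidingWindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sliding-window").getOrCreate()
    import spark.implicits._

    val events = Seq("2024-01-01 00:01:00", "2024-01-01 00:07:30")
      .toDF("ts_string")
      .select(col("ts_string").cast("timestamp").as("ts"))

    // windowDuration = 10 min, slideDuration = 3 min: ceil(10 / 3) = 4 candidate windows
    // per row; the filter on window.start/window.end keeps only the valid ones.
    events.groupBy(window(col("ts"), "10 minutes", "3 minutes"))
      .agg(count("*").as("n"))
      .orderBy("window")
      .show(truncate = false)

    spark.stop()
  }
}
```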
+ val filterExpr = + if (window.windowDuration % window.slideDuration == 0) { + IsNotNull(window.timeColumn) + } else { + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + } + + val substitutedPlan = Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + renamedPlan.withNewChildren(substitutedPlan :: Nil) } - - val substitutedPlan = Filter(filterExpr, - Expand(projections, windowAttr +: child.output, child)) - - val renamedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } - - renamedPlan.withNewChildren(substitutedPlan :: Nil) } } else if (numWindowExpr > 1) { throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) @@ -209,71 +209,71 @@ object SessionWindowing extends Rule[LogicalPlan] { val session = sessionExpressions.head if (StructType.acceptsType(session.timeColumn.dataType)) { - return p transformExpressions { + p transformExpressions { case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn)) } - } + } else { + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } - val metadata = session.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() - val newMetadata = new MetadataBuilder() - .withMetadata(metadata) - .putBoolean(SessionWindow.marker, true) - .build() - - val sessionAttr = AttributeReference( - SESSION_COL_NAME, session.dataType, metadata = newMetadata)() - - val sessionStart = - PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) - val gapDuration = session.gapDuration match { - case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => - Cast(expr, CalendarIntervalType) - case other => - throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) - } - val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, - session.timeColumn.dataType, LongType) - - // We make sure value fields are nullable since the dataType of SessionWindow defines them - // as nullable. 
- val literalSessionStruct = CreateNamedStruct( - Literal(SESSION_START) :: - PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) - .castNullable() :: - Literal(SESSION_END) :: - PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) - .castNullable() :: - Nil) - - val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( - exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() - val replacedPlan = p transformExpressions { - case s: SessionWindow => sessionAttr - } + val sessionStart = + PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) + val gapDuration = session.gapDuration match { + case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => + Cast(expr, CalendarIntervalType) + case other => + throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) + } + val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, + session.timeColumn.dataType, LongType) - val filterByTimeRange = session.gapDuration match { - case Literal(interval: CalendarInterval, CalendarIntervalType) => - interval == null || interval.months + interval.days + interval.microseconds <= 0 - case _ => true - } + // We make sure value fields are nullable since the dataType of SessionWindow defines them + // as nullable. + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) + .castNullable() :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) + .castNullable() :: + Nil) - // As same as tumbling window, we add a filter to filter out nulls. - // And we also filter out events with negative or zero or invalid gap duration. - val filterExpr = if (filterByTimeRange) { - IsNotNull(session.timeColumn) && - (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) - } else { - IsNotNull(session.timeColumn) - } + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) - replacedPlan.withNewChildren( - Filter(filterExpr, - Project(sessionStruct +: child.output, child)) :: Nil) + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + val filterByTimeRange = session.gapDuration match { + case Literal(interval: CalendarInterval, CalendarIntervalType) => + interval == null || interval.months + interval.days + interval.microseconds <= 0 + case _ => true + } + + // As same as tumbling window, we add a filter to filter out nulls. + // And we also filter out events with negative or zero or invalid gap duration. 
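For reference, the user-facing shape of the rewrite above via the public `session_window` function (sample data invented): events within the gap merge into one session, and the generated filter drops null event times and sessions whose end does not exceed their start.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, session_window}

object SessionWindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("session-window").getOrCreate()
    import spark.implicits._

    val events = Seq(
      ("u1", "2024-01-01 00:00:10"),
      ("u1", "2024-01-01 00:03:00"),  // within the 5-minute gap: same session
      ("u1", "2024-01-01 00:20:00"))  // gap exceeded: new session
      .toDF("user", "ts_string")
      .select(col("user"), col("ts_string").cast("timestamp").as("ts"))

    // The analyzer rewrites session_window into a session struct column plus the filter
    // described above (non-null event time and session_end > session_start).
    events.groupBy(col("user"), session_window(col("ts"), "5 minutes"))
      .agg(count("*").as("events"))
      .show(truncate = false)

    spark.stop()
  }
}
```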
+ val filterExpr = if (filterByTimeRange) { + IsNotNull(session.timeColumn) && + (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) + } else { + IsNotNull(session.timeColumn) + } + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } } else if (numWindowExpr > 1) { throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala index 78b776f12f074..f1077378b2d9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala @@ -51,7 +51,7 @@ object ResolveWithCTE extends Rule[LogicalPlan] { case ref: CTERelationRef if !ref.resolved => cteDefMap.get(ref.cteId).map { cteDef => - CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output) + CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming) }.getOrElse { ref } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 42abc0eafda7a..fabb5634ad10c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -371,8 +371,16 @@ object TableOutputResolver { resolveColumnsByPosition(tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath) } if (res.length == 1) { - val func = LambdaFunction(res.head, Seq(param)) - Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)()) + if (res.head == param) { + // If the element type is the same, we can reuse the input array directly. + Some( + Alias(nullCheckedInput, expected.name)( + nonInheritableMetadataKeys = + Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))) + } else { + val func = LambdaFunction(res.head, Seq(param)) + Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)()) + } } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 190e72a8e669e..c9a4a2d40246a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1102,22 +1102,22 @@ object TypeCoercion extends TypeCoercionBase { case a @ BinaryArithmetic(left @ StringTypeExpression(), right) if right.dataType != CalendarIntervalType => - a.makeCopy(Array(Cast(left, DoubleType), right)) + a.withNewChildren(Seq(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringTypeExpression()) if left.dataType != CalendarIntervalType => - a.makeCopy(Array(left, Cast(right, DoubleType))) + a.withNewChildren(Seq(left, Cast(right, DoubleType))) // For equality between string and timestamp we cast the string to a timestamp // so that things like rounding of subsecond precision does not affect the comparison. 
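A quick way to observe the coercion the comment above describes (the switch from makeCopy to withNewChildren in this patch does not change it): comparing a timestamp with a string should insert a cast of the string side to TimestampType in the analyzed plan.

```scala
import org.apache.spark.sql.SparkSession

object StringTimestampEqualityExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ts-equality").getOrCreate()

    // The analyzed plan should show a Cast to TIMESTAMP on the string side of the
    // equality, per the TypeCoercion rule above.
    spark.sql(
      "SELECT TIMESTAMP'2024-01-01 00:00:00' = '2024-01-01 00:00:00' AS same"
    ).explain(extended = true)

    spark.stop()
  }
}
```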
case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) + p.withNewChildren(Seq(Cast(left, TimestampType), right)) case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) + p.withNewChildren(Seq(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get - p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index 2e3cabce24a4b..20cc9a2ab5450 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -136,7 +136,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase { args(posToIndex(pos)) } - case _ => plan + case other => other } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1c72ec0d69980..81d92acc6e84a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Unary import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.TableWritePrivilege import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Metadata, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -55,9 +56,19 @@ trait UnresolvedUnaryNode extends UnaryNode with UnresolvedNode */ case class PlanWithUnresolvedIdentifier( identifierExpr: Expression, - planBuilder: Seq[String] => LogicalPlan) - extends UnresolvedLeafNode { + children: Seq[LogicalPlan], + planBuilder: (Seq[String], Seq[LogicalPlan]) => LogicalPlan) + extends UnresolvedNode { + + def this(identifierExpr: Expression, planBuilder: Seq[String] => LogicalPlan) = { + this(identifierExpr, Nil, (ident, _) => planBuilder(ident)) + } + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = + copy(identifierExpr, newChildren, planBuilder) } /** @@ -96,20 +107,45 @@ case class UnresolvedRelation( override def name: String = tableName + def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = { + if (privileges.nonEmpty) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(",")) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } + } + + def clearWritePrivileges: UnresolvedRelation = { + 
if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } + } + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION) } object UnresolvedRelation { + // An internal option of `UnresolvedRelation` to specify the required write privileges when + // writing data to this relation. + val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__" + def apply( tableIdentifier: TableIdentifier, extraOptions: CaseInsensitiveStringMap, isStreaming: Boolean): UnresolvedRelation = { - UnresolvedRelation( - tableIdentifier.database.toSeq :+ tableIdentifier.table, extraOptions, isStreaming) + UnresolvedRelation(tableIdentifier.nameParts, extraOptions, isStreaming) } def apply(tableIdentifier: TableIdentifier): UnresolvedRelation = - UnresolvedRelation(tableIdentifier.database.toSeq :+ tableIdentifier.table) + UnresolvedRelation(tableIdentifier.nameParts) } /** @@ -127,6 +163,21 @@ case class UnresolvedInlineTable( lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved)) } +/** + * An resolved inline table that holds all the expressions that were checked for + * the right shape and common data types. + * This is a preparation step for [[org.apache.spark.sql.catalyst.optimizer.EvalInlineTables]] which + * will produce a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]] + * for this inline table. + * + * @param output list of column attributes + * @param rows expressions for the data rows + */ +case class ResolvedInlineTable(rows: Seq[Seq[Expression]], output: Seq[Attribute]) + extends LeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(INLINE_TABLE_EVAL) +} + /** * A table-valued function, e.g. * {{{ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 392c911ddb8e0..0de9673a5f968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subque import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -227,7 +228,8 @@ class SessionCatalog( /** This method discards any cached table relation plans for the given table identifier. */ def invalidateCachedTable(name: TableIdentifier): Unit = { val qualified = qualifyIdentifier(name) - invalidateCachedTable(QualifiedTableName(qualified.database.get, qualified.table)) + invalidateCachedTable(QualifiedTableName( + qualified.catalog.get, qualified.database.get, qualified.table)) } /** This method provides a way to invalidate all the cached plans. 
*/ @@ -295,7 +297,7 @@ class SessionCatalog( } if (cascade && databaseExists(dbName)) { listTables(dbName).foreach { t => - invalidateCachedTable(QualifiedTableName(dbName, t.table)) + invalidateCachedTable(QualifiedTableName(SESSION_CATALOG_NAME, dbName, t.table)) } } externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) @@ -330,12 +332,16 @@ class SessionCatalog( def getCurrentDatabase: String = synchronized { currentDb } def setCurrentDatabase(db: String): Unit = { + setCurrentDatabaseWithNameCheck(db, requireDbExists) + } + + def setCurrentDatabaseWithNameCheck(db: String, nameCheck: String => Unit): Unit = { val dbName = format(db) if (dbName == globalTempViewManager.database) { throw QueryCompilationErrors.cannotUsePreservedDatabaseAsCurrentDatabaseError( globalTempViewManager.database) } - requireDbExists(dbName) + nameCheck(dbName) synchronized { currentDb = dbName } } @@ -1126,7 +1132,8 @@ class SessionCatalog( def refreshTable(name: TableIdentifier): Unit = synchronized { getLocalOrGlobalTempView(name).map(_.refresh).getOrElse { val qualifiedIdent = qualifyIdentifier(name) - val qualifiedTableName = QualifiedTableName(qualifiedIdent.database.get, qualifiedIdent.table) + val qualifiedTableName = QualifiedTableName( + qualifiedIdent.catalog.get, qualifiedIdent.database.get, qualifiedIdent.table) tableRelationCache.invalidate(qualifiedTableName) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 51586a0065e95..2c27da3cf6e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -66,6 +66,8 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private val LENIENT_TS_FORMATTER_SUPPORTED_DATE_FORMATS = Set( "yyyy-MM-dd", "yyyy-M-d", "yyyy-M-dd", "yyyy-MM-d", "yyyy-MM", "yyyy-M", "yyyy") + private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType + /** * Similar to the JSON schema inference * 1. Infer type of each row @@ -199,11 +201,12 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } private def tryParseTimestampNTZ(field: String): DataType = { - // We can only parse the value as TimestampNTZType if it does not have zone-offset or - // time-zone component and can be parsed with the timestamp formatter. - // Otherwise, it is likely to be a timestamp with timezone. - if (timestampNTZFormatter.parseWithoutTimeZoneOptional(field, false).isDefined) { - SQLConf.get.timestampType + // For text-based format, it's ambiguous to infer a timestamp string without timezone, as it can + // be both TIMESTAMP LTZ and NTZ. To avoid behavior changes with the new support of NTZ, here + // we only try to infer NTZ if the config is set to use NTZ by default. 
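As a side note on the NTZ inference change above, the decision can be pictured with a minimal, dependency-free Scala sketch (not the patch itself). Here `parsesAsNtz` is a hypothetical stand-in for `timestampNTZFormatter.parseWithoutTimeZoneOptional`, and `defaultIsNtz` stands for `SQLConf.get.timestampType == TimestampNTZType`.

    import java.time.LocalDateTime
    import java.time.format.DateTimeFormatter
    import scala.util.Try

    object NtzInferenceSketch {
      private val ntzFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")

      // Hypothetical stand-in for the NTZ formatter used during schema inference.
      private def parsesAsNtz(field: String): Boolean =
        Try(LocalDateTime.parse(field, ntzFormat)).isSuccess

      // Mirrors the guarded branch above: NTZ is only inferred when it is the session
      // default; otherwise the field falls through to the LTZ parsing path.
      def inferredType(field: String, defaultIsNtz: Boolean): String =
        if (defaultIsNtz && parsesAsNtz(field)) "TIMESTAMP_NTZ"
        else "fall back to tryParseTimestamp (LTZ or string)"

      def main(args: Array[String]): Unit = {
        println(inferredType("2024-01-01 10:00:00", defaultIsNtz = true))  // TIMESTAMP_NTZ
        println(inferredType("2024-01-01 10:00:00", defaultIsNtz = false)) // falls back
      }
    }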
+ if (isDefaultNTZ && + timestampNTZFormatter.parseWithoutTimeZoneOptional(field, false).isDefined) { + TimestampNTZType } else { tryParseTimestamp(field) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 845c815c5648b..c5a6bf5076dec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -277,6 +277,15 @@ class CSVOptions( val unescapedQuoteHandling: UnescapedQuoteHandling = UnescapedQuoteHandling.valueOf(parameters .getOrElse(UNESCAPED_QUOTE_HANDLING, "STOP_AT_DELIMITER").toUpperCase(Locale.ROOT)) + /** + * The column pruning feature can be enabled either via the CSV option `columnPruning` or + * in non-multiline mode via initialization of CSV options by the SQL config: + * `spark.sql.csv.parser.columnPruning.enabled`. + * The feature is disabled in the `multiLine` mode because of the issue: + * https://github.com/uniVocity/univocity-parsers/issues/529 + */ + val isColumnPruningEnabled: Boolean = getBool(COLUMN_PRUNING, !multiLine && columnPruning) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -376,4 +385,5 @@ object CSVOptions extends DataSourceOptions { val SEP = "sep" val DELIMITER = "delimiter" newOption(SEP, DELIMITER) + val COLUMN_PRUNING = newOption("columnPruning") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index b99ee630d4b22..f0663ddd69b1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -72,7 +72,7 @@ class UnivocityParser( // positions. Generally assigned by input configuration options, except when input column(s) have // default values, in which case we omit the explicit indexes in order to know how many tokens // were present in each line instead. - private def columnPruning: Boolean = options.columnPruning && + private def columnPruning: Boolean = options.isColumnPruningEnabled && !requiredSchema.exists(_.metadata.contains(EXISTS_DEFAULT_COLUMN_METADATA_KEY)) // When column pruning is enabled, the parser only parses the required columns based on @@ -139,6 +139,7 @@ class UnivocityParser( // Retrieve the raw record string. private def getCurrentInput: UTF8String = { + if (tokenizer.getContext == null) return null val currentContent = tokenizer.getContext.currentParsedContent() if (currentContent == null) null else UTF8String.fromString(currentContent.stripLineEnd) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index ff72b5a0d9653..6dc89bb4d4b1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -73,8 +73,14 @@ object ExpressionEncoder { * Given a set of N encoders, constructs a new encoder that produce objects as items in an * N-tuple. Note that these encoders should be unresolved so that information about * name/positional binding is preserved. 
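Stepping back to the CSV `columnPruning` option registered a little earlier: a usage-level sketch of how a caller could opt out of pruning for a single read, independent of `spark.sql.csv.parser.columnPruning.enabled`. The local SparkSession, the `/tmp/data.csv` path, and the column name `a` are assumptions for illustration only.

    import org.apache.spark.sql.SparkSession

    object CsvColumnPruningOptionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()

        // Per-read override via the new `columnPruning` data source option.
        val df = spark.read
          .option("header", "true")
          .option("columnPruning", "false")
          .csv("/tmp/data.csv") // hypothetical input path

        // With pruning disabled, the parser tokenizes full rows even though
        // only column `a` is projected.
        df.select("a").show()
        spark.stop()
      }
    }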
+ * When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if + * the input is null. It is false by default as most deserializers handle null input properly and + * don't require an extra null check. Some of them are null-tolerant, such as the deserializer for + * `Option[T]`, and we must not set it to true in this case. */ - def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + def tuple( + encoders: Seq[ExpressionEncoder[_]], + useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = { if (encoders.length > 22) { throw QueryExecutionErrors.elementsOfTupleExceedLimitError() } @@ -119,7 +125,7 @@ object ExpressionEncoder { case GetColumnByOrdinal(0, _) => input } - if (enc.objSerializer.nullable) { + if (useNullSafeDeserializer && enc.objSerializer.nullable) { nullSafe(input, childDeserializer) } else { childDeserializer diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 24ade61c12149..01509c5b968c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1615,7 +1615,8 @@ case class Cast( val block = inline"new java.math.BigDecimal($MICROS_PER_SECOND)" code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" } - private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * (long)$MICROS_PER_SECOND" + private[this] def longToTimeStampCode(l: ExprValue): Block = + code"java.util.concurrent.TimeUnit.SECONDS.toMicros($l)" private[this] def timestampToLongCode(ts: ExprValue): Block = code"java.lang.Math.floorDiv($ts, $MICROS_PER_SECOND)" private[this] def timestampToDoubleCode(ts: ExprValue): Block = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 1a84859cc3a15..5d8d428e27d68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -79,10 +79,6 @@ class EquivalentExpressions( case _ => if (useCount > 0) { map.put(wrapper, ExpressionStats(expr)(useCount)) - } else { - // Should not happen - throw new IllegalStateException( - s"Cannot update expression: $expr in map: $map with use count: $useCount") } false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c2330cdb59dbc..bd7369e57b057 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1410,4 +1410,6 @@ case class MultiCommutativeOp( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this.copy(operands = newChildren)(originalRoot) + + override protected final def otherCopyArgs: Seq[AnyRef] = originalRoot :: Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 
133a39d987459..316cb9e0bbc34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -353,7 +353,7 @@ case class PandasStddev( override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === ddof, divideByZeroEvalResult, sqrt(m2 / (n - ddof)))) + If(n === ddof.toDouble, divideByZeroEvalResult, sqrt(m2 / (n - ddof.toDouble)))) } override def prettyName: String = "pandas_stddev" @@ -375,7 +375,7 @@ case class PandasVariance( override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === ddof, divideByZeroEvalResult, m2 / (n - ddof))) + If(n === ddof.toDouble, divideByZeroEvalResult, m2 / (n - ddof.toDouble))) } override def prettyName: String = "pandas_variance" @@ -405,8 +405,8 @@ case class PandasSkewness(child: Expression) val _m2 = If(abs(m2) < 1e-14, Literal(0.0), m2) val _m3 = If(abs(m3) < 1e-14, Literal(0.0), m3) - If(n < 3, Literal.create(null, DoubleType), - If(_m2 === 0.0, Literal(0.0), sqrt(n - 1) * (n / (n - 2)) * _m3 / sqrt(_m2 * _m2 * _m2))) + If(n < 3.0, Literal.create(null, DoubleType), + If(_m2 === 0.0, Literal(0.0), sqrt(n - 1.0) * (n / (n - 2.0)) * _m3 / sqrt(_m2 * _m2 * _m2))) } override protected def withNewChildInternal(newChild: Expression): PandasSkewness = @@ -423,9 +423,9 @@ case class PandasKurtosis(child: Expression) override protected def momentOrder = 4 override val evaluateExpression: Expression = { - val adj = ((n - 1) / (n - 2)) * ((n - 1) / (n - 3)) * 3 - val numerator = n * (n + 1) * (n - 1) * m4 - val denominator = (n - 2) * (n - 3) * m2 * m2 + val adj = ((n - 1.0) / (n - 2.0)) * ((n - 1.0) / (n - 3.0)) * 3.0 + val numerator = n * (n + 1.0) * (n - 1.0) * m4 + val denominator = (n - 2.0) * (n - 3.0) * m2 * m2 // floating point error // @@ -436,7 +436,7 @@ case class PandasKurtosis(child: Expression) val _numerator = If(abs(numerator) < 1e-14, Literal(0.0), numerator) val _denominator = If(abs(denominator) < 1e-14, Literal(0.0), denominator) - If(n < 4, Literal.create(null, DoubleType), + If(n < 4.0, Literal.create(null, DoubleType), If(_denominator === 0.0, Literal(0.0), _numerator / _denominator - adj)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index ff31fb1128b9b..b392b603ab8d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -156,7 +156,7 @@ case class PandasCovar( override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === ddof, divideByZeroEvalResult, ck / (n - ddof))) + If(n === ddof.toDouble, divideByZeroEvalResult, ck / (n - ddof.toDouble))) } override def prettyName: String = "pandas_covar" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index 40518982958cd..7d73cf211a6e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -271,8 +271,14 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues - override lazy val updateExpressions: Seq[Expression] = - covarPop.updateExpressions ++ varPop.updateExpressions + override lazy val updateExpressions: Seq[Expression] = { + // RegrSlope only handles pairs where both y and x are non-empty, so we need an additional + // null check when updating VariancePop. + val isNull = left.isNull || right.isNull + covarPop.updateExpressions ++ varPop.updateExpressions.zip(varPop.aggBufferAttributes).map { + case (newValue, oldValue) => If(isNull, oldValue, newValue) + } + } override lazy val mergeExpressions: Seq[Expression] = covarPop.mergeExpressions ++ varPop.mergeExpressions @@ -324,8 +330,14 @@ case class RegrIntercept(left: Expression, right: Expression) extends Declarativ override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues - override lazy val updateExpressions: Seq[Expression] = - covarPop.updateExpressions ++ varPop.updateExpressions + override lazy val updateExpressions: Seq[Expression] = { + // RegrIntercept only handles pairs where both y and x are non-empty, so we need an additional + // null check when updating VariancePop. + val isNull = left.isNull || right.isNull + covarPop.updateExpressions ++ varPop.updateExpressions.zip(varPop.aggBufferAttributes).map { + case (newValue, oldValue) => If(isNull, oldValue, newValue) + } + } override lazy val mergeExpressions: Seq[Expression] = covarPop.mergeExpressions ++ varPop.mergeExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 6061f625ef07b..183e5d6697e99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -229,7 +229,7 @@ case class BitwiseCount(child: Expression) override def prettyName: String = "bit_count" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match { - case BooleanType => defineCodeGen(ctx, ev, c => s"if ($c) 1 else 0") + case BooleanType => defineCodeGen(ctx, ev, c => s"($c) ?
1 : 0") case _ => defineCodeGen(ctx, ev, c => s"java.lang.Long.bitCount($c)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fe9c4015c15ec..45896382af672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,6 +22,7 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -40,7 +41,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String} /** @@ -712,6 +712,7 @@ case class MapConcat(children: Seq[Expression]) } } + override def stateful: Boolean = true override def nullable: Boolean = children.exists(_.nullable) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) @@ -827,6 +828,8 @@ case class MapFromEntries(child: Expression) override def nullable: Boolean = child.nullable || nullEntries + override def stateful: Boolean = true + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { @@ -3080,6 +3083,34 @@ case class Sequence( } object Sequence { + private def prettyName: String = "sequence" + + def sequenceLength(start: Long, stop: Long, step: Long): Int = { + try { + val delta = Math.subtractExact(stop, start) + if (delta == Long.MinValue && step == -1L) { + // We must special-case division of Long.MinValue by -1 to catch potential unchecked + // overflow in next operation. Division does not have a builtin overflow check. We + // previously special-case div-by-zero. + throw new ArithmeticException("Long overflow (Long.MinValue / -1)") + } + val len = if (stop == start) 1L else Math.addExact(1L, (delta / step)) + if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len) + } + len.toInt + } catch { + // We handle overflows in the previous try block by raising an appropriate exception. + case _: ArithmeticException => + val safeLen = + BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step) + if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(safeLen) + } + throw internalError("Unreachable code reached.") + case e: Exception => throw e + } + } private type LessThanOrEqualFn = (Any, Any) => Boolean @@ -3451,13 +3482,7 @@ object Sequence { || (estimatedStep == num.zero && start == stop), s"Illegal sequence boundaries: $start to $stop by $step") - val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong - - require( - len <= MAX_ROUNDED_ARRAY_LENGTH, - s"Too long sequence: $len. 
Should be <= $MAX_ROUNDED_ARRAY_LENGTH") - - len.toInt + sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong) } private def genSequenceLengthCode( @@ -3467,7 +3492,7 @@ object Sequence { step: String, estimatedStep: String, len: String): String = { - val longLen = ctx.freshName("longLen") + val calcFn = classOf[Sequence].getName + ".sequenceLength" s""" |if (!(($estimatedStep > 0 && $start <= $stop) || | ($estimatedStep < 0 && $start >= $stop) || @@ -3475,12 +3500,7 @@ object Sequence { | throw new IllegalArgumentException( | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); |} - |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep; - |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { - | throw new IllegalArgumentException( - | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); - |} - |int $len = (int) $longLen; + |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep); """.stripMargin } } @@ -4711,7 +4731,6 @@ case class ArrayInsert( } case (e1, e2, e3) => Seq.empty } - Seq.empty } override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 2051219131219..1b6f86984be77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -242,6 +242,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def stateful: Boolean = true + override def eval(input: InternalRow): Any = { var i = 0 while (i < keys.length) { @@ -317,6 +319,8 @@ case class MapFromArrays(left: Expression, right: Expression) valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } + override def stateful: Boolean = true + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { @@ -372,6 +376,7 @@ object CreateStruct { // alias name inside CreateNamedStruct. 
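Returning to the `Sequence.sequenceLength` helper introduced just above: a dependency-free sketch of its overflow handling. `MAX_LEN` is an assumed stand-in for `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, and the error types are simplified relative to `QueryExecutionErrors`.

    import scala.util.Try

    object SequenceLengthSketch {
      // Assumed stand-in for MAX_ROUNDED_ARRAY_LENGTH.
      private val MAX_LEN: Long = Int.MaxValue - 15

      def sequenceLength(start: Long, stop: Long, step: Long): Int =
        try {
          val delta = Math.subtractExact(stop, start)
          if (delta == Long.MinValue && step == -1L) {
            // Long.MinValue / -1 overflows without throwing, so force the overflow path.
            throw new ArithmeticException("long overflow")
          }
          val len = if (stop == start) 1L else Math.addExact(1L, delta / step)
          require(len <= MAX_LEN, s"Too long sequence: $len")
          len.toInt
        } catch {
          case _: ArithmeticException =>
            // Recompute with BigInt so the error can report the true length.
            val safeLen = BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
            require(safeLen <= BigInt(MAX_LEN), s"Too long sequence: $safeLen")
            // Mirrors the patch: any remaining overflow is treated as an internal error.
            sys.error("unreachable code reached")
        }

      def main(args: Array[String]): Unit = {
        println(sequenceLength(1L, 10L, 3L))                // 4
        println(Try(sequenceLength(0L, Long.MaxValue, 1L))) // Failure: too long sequence
      }
    }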
case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u) case (u @ UnresolvedExtractValue(_, e: Literal), _) if e.dataType == StringType => Seq(e, u) + case (a: Alias, _) => Seq(Literal(a.name), a) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) case (e, index) => Seq(Literal(s"col${index + 1}"), e) @@ -562,6 +567,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E this(child, Literal(","), Literal(":")) } + override def stateful: Boolean = true + override def first: Expression = text override def second: Expression = pairDelim override def third: Expression = keyValueDelim diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index fec1df108bccf..5b10b401af98d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -918,6 +918,8 @@ case class TransformKeys( override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull) + override def stateful: Boolean = true + override def checkInputDataTypes(): TypeCheckResult = { TypeUtils.checkForMapKeyType(function.dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index add59a38b7201..b9a2cb348e380 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1509,7 +1509,7 @@ abstract class RoundBase(child: Expression, scale: Expression, DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> "scala", + "inputName" -> "scale", "inputType" -> toSQLType(scale.dataType), "inputExpr" -> toSQLExpr(scale))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 2d4f0438db760..9dcca65efe5a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -86,6 +86,7 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $builder.parse(${eval.value}); + | ${ev.isNull} = ${ev.isNull} || (${ev.value} == null); |} """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 46f8e1a9d673d..0da6d171a1dd4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2419,24 +2419,41 @@ case class Chr(child: Expression) """, since = "1.5.0", group = "string_funcs") -case class 
Base64(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { +case class Base64(child: Expression, chunkBase64: Boolean) + extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes { + + def this(expr: Expression) = this(expr, SQLConf.get.chunkBase64StringEnabled) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) - protected override def nullSafeEval(bytes: Any): Any = { - UTF8String.fromBytes(JBase64.getMimeEncoder.encode(bytes.asInstanceOf[Array[Byte]])) - } + override def replacement: Expression = StaticInvoke( + classOf[Base64], + dataType, + "encode", + Seq(child, Literal(chunkBase64, BooleanType)), + Seq(BinaryType, BooleanType), + returnNullable = false) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (child) => { - s"""${ev.value} = UTF8String.fromBytes( - ${classOf[JBase64].getName}.getMimeEncoder().encode($child)); - """}) - } + override def toString: String = s"$prettyName($child)" - override protected def withNewChildInternal(newChild: Expression): Base64 = copy(child = newChild) + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +object Base64 { + def apply(expr: Expression): Base64 = new Base64(expr) + + private lazy val nonChunkEncoder = JBase64.getMimeEncoder(-1, Array()) + + def encode(input: Array[Byte], chunkBase64: Boolean): UTF8String = { + val encoder = if (chunkBase64) { + JBase64.getMimeEncoder + } else { + nonChunkEncoder + } + UTF8String.fromBytes(encoder.encode(input)) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 47b37a5edeba8..0dd14d313e17e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -115,7 +115,7 @@ object UrlCodec { UTF8String.fromString(URLDecoder.decode(src.toString, enc.toString)) } catch { case e: IllegalArgumentException => - throw QueryExecutionErrors.illegalUrlError(src) + throw QueryExecutionErrors.illegalUrlError(src, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 50c98c01645d9..a4ce78d1bb6d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -850,7 +850,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow // for each partition. 
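The non-chunked encoder above relies on `java.util.Base64`'s MIME encoder accepting a non-positive line length. A standalone sketch of the difference the new `chunkBase64` flag switches between; the input string is arbitrary.

    import java.nio.charset.StandardCharsets
    import java.util.{Base64 => JBase64}

    object Base64ChunkSketch {
      def main(args: Array[String]): Unit = {
        val input = ("spark" * 20).getBytes(StandardCharsets.UTF_8)

        // The default MIME encoder wraps output at 76 characters with CRLF separators.
        val chunked = new String(JBase64.getMimeEncoder.encode(input), StandardCharsets.UTF_8)
        // A non-positive line length disables wrapping, matching the non-chunked path above.
        val unchunked = new String(
          JBase64.getMimeEncoder(-1, Array.emptyByteArray).encode(input), StandardCharsets.UTF_8)

        println(chunked.contains("\r\n"))   // true
        println(unchunked.contains("\r\n")) // false
      }
    }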
override def checkInputDataTypes(): TypeCheckResult = { if (!buckets.foldable) { - DataTypeMismatch( + return DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> "buckets", @@ -861,7 +861,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow } if (buckets.dataType != IntegerType) { - DataTypeMismatch( + return DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> "1", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 2ed13944be9af..105931cf4f803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -239,13 +239,16 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract { Examples: > SELECT _FUNC_('b1b2b3c1c2','a/b/text()'); ["b1","b2","b3"] + > SELECT _FUNC_('b1b2b3c1c2','a/b'); + [null,null,null] """, since = "2.0.0", group = "xml_funcs") // scalastyle:on line.size.limit case class XPathList(xml: Expression, path: Expression) extends XPathExtract { override def prettyName: String = "xpath" - override def dataType: DataType = ArrayType(StringType, containsNull = false) + + override def dataType: DataType = ArrayType(StringType) override def nullSafeEval(xml: Any, path: Any): Any = { val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 2f818fecad93a..ceced9313940a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.connector.catalog.CatalogManager + /** * An identifier that optionally specifies a database. * @@ -107,8 +109,23 @@ case class TableIdentifier(table: String, database: Option[String], catalog: Opt } /** A fully qualified identifier for a table (i.e., database.tableName) */ -case class QualifiedTableName(database: String, name: String) { - override def toString: String = s"$database.$name" +case class QualifiedTableName(catalog: String, database: String, name: String) { + /** Two argument ctor for backward compatibility. 
*/ + def this(database: String, name: String) = this( + catalog = CatalogManager.SESSION_CATALOG_NAME, + database = database, + name = name) + + override def toString: String = s"$catalog.$database.$name" +} + +object QualifiedTableName { + def apply(catalog: String, database: String, name: String): QualifiedTableName = { + new QualifiedTableName(catalog, database, name) + } + + def apply(database: String, name: String): QualifiedTableName = + new QualifiedTableName(database = database, name = name) } object TableIdentifier { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index f14f70532e659..e043230c683c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -180,7 +180,18 @@ class JacksonParser( // val st = at.elementType.asInstanceOf[StructType] val fieldConverters = st.map(_.dataType).map(makeConverter).toArray - Some(InternalRow(new GenericArrayData(convertObject(parser, st, fieldConverters).toArray))) + + val res = try { + convertObject(parser, st, fieldConverters) + } catch { + case err: PartialResultException => + throw PartialArrayDataResultException( + new GenericArrayData(Seq(err.partialResult)), + err.cause + ) + } + + Some(InternalRow(new GenericArrayData(res.toArray))) } } @@ -497,9 +508,9 @@ class JacksonParser( try { values += fieldConverter.apply(parser) } catch { - case PartialResultException(row, cause) if enablePartialResults => - badRecordException = badRecordException.orElse(Some(cause)) - values += row + case err: PartialValueException if enablePartialResults => + badRecordException = badRecordException.orElse(Some(err.cause)) + values += err.partialResult case NonFatal(e) if enablePartialResults => badRecordException = badRecordException.orElse(Some(e)) parser.skipChildren() @@ -534,9 +545,9 @@ class JacksonParser( if (isRoot && v == null) throw QueryExecutionErrors.rootConverterReturnNullError() values += v } catch { - case PartialResultException(row, cause) if enablePartialResults => - badRecordException = badRecordException.orElse(Some(cause)) - values += row + case err: PartialValueException if enablePartialResults => + badRecordException = badRecordException.orElse(Some(err.cause)) + values += err.partialResult } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 5385afe8c9353..f6d32f39f64ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { @@ -53,6 +54,9 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { isParsing = true, 
forTimestampNTZ = true) + private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType + private val legacyMode = SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY + private def handleJsonErrorsByParseMode(parseMode: ParseMode, columnNameOfCorruptRecord: String, e: Throwable): Option[StructType] = { parseMode match { @@ -150,12 +154,28 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { } if (options.prefersDecimal && decimalTry.isDefined) { decimalTry.get - } else if (options.inferTimestamp && + } else if (options.inferTimestamp) { + // For text-based format, it's ambiguous to infer a timestamp string without timezone, as + // it can be both TIMESTAMP LTZ and NTZ. To avoid behavior changes with the new support + // of NTZ, here we only try to infer NTZ if the config is set to use NTZ by default. + if (isDefaultNTZ && timestampNTZFormatter.parseWithoutTimeZoneOptional(field, false).isDefined) { - SQLConf.get.timestampType - } else if (options.inferTimestamp && - timestampFormatter.parseOptional(field).isDefined) { - TimestampType + TimestampNTZType + } else if (timestampFormatter.parseOptional(field).isDefined) { + TimestampType + } else if (legacyMode) { + val utf8Value = UTF8String.fromString(field) + // There was a mistake that we use TIMESTAMP NTZ parser to infer LTZ type with legacy + // mode. The mistake makes it easier to infer TIMESTAMP LTZ type and we have to keep + // this behavior now. See SPARK-46769 for more details. + if (SparkDateTimeUtils.stringToTimestampWithoutTimeZone(utf8Value, false).isDefined) { + TimestampType + } else { + StringType + } + } else { + StringType + } } else { StringType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala index 261be2914630e..04204c6a2e108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala @@ -52,23 +52,33 @@ object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper { if (limits.nonEmpty) Some(limits.min) else None } - private def support( + /** + * All window expressions should use the same expanding window, so that + * we can safely do the early stop. 
+ */ + private def isExpandingWindow( windowExpression: NamedExpression): Boolean = windowExpression match { - case Alias(WindowExpression(_: Rank | _: DenseRank | _: RowNumber, WindowSpecDefinition(_, _, + case Alias(WindowExpression(_, WindowSpecDefinition(_, _, SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true case _ => false } + private def support(windowFunction: Expression): Boolean = windowFunction match { + case _: Rank | _: DenseRank | _: RowNumber => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = { if (conf.windowGroupLimitThreshold == -1) return plan plan.transformWithPruning(_.containsAllPatterns(FILTER, WINDOW), ruleId) { case filter @ Filter(condition, window @ Window(windowExpressions, partitionSpec, orderSpec, child)) - if !child.isInstanceOf[WindowGroupLimit] && windowExpressions.exists(support) && + if !child.isInstanceOf[WindowGroupLimit] && windowExpressions.forall(isExpandingWindow) && orderSpec.nonEmpty => val limits = windowExpressions.collect { - case alias @ Alias(WindowExpression(rankLikeFunction, _), _) if support(alias) => + case alias @ Alias(WindowExpression(rankLikeFunction, _), _) + if support(rankLikeFunction) => extractLimits(condition, alias.toAttribute).map((_, rankLikeFunction)) }.flatten diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 8d7ff4cbf163d..69adf3e15cf44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -79,41 +79,51 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * - The number of incoming references to the CTE. This includes references from * other CTEs and regular places. * - A mutable inner map that tracks outgoing references (counts) to other CTEs. - * @param outerCTEId While collecting the map we use this optional CTE id to identify the - * current outer CTE. + * @param collectCTERefs A function to collect CTE references so that the caller side can do some + * bookkeeping work. */ def buildCTEMap( plan: LogicalPlan, cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - outerCTEId: Option[Long] = None): Unit = { + collectCTERefs: CTERelationRef => Unit = _ => ()): Unit = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0)) } cteDefs.foreach { cteDef => - buildCTEMap(cteDef, cteMap, Some(cteDef.id)) + buildCTEMap(cteDef, cteMap, ref => { + // A CTE relation can reference CTE relations defined before it in the same `WithCTE`. + // Here we update the out-going-ref-count for it, in case this CTE relation is not + // referenced at all and gets optimized out, so that we can then decrease the ref counts + // of the CTE relations it references. + if (cteDefs.exists(_.id == ref.cteId)) { + val (_, _, outerRefMap) = cteMap(cteDef.id) + outerRefMap(ref.cteId) += 1 + } + // Similarly, a CTE relation can reference CTE relations defined in the outer `WithCTE`. + // Here we call the `collectCTERefs` function so that the outer CTE can also update the + // out-going-ref-count if needed.
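A usage-level sketch (assuming a local SparkSession, which is not part of the patch) of the situation the out-going-ref bookkeeping above is for: `unused` references `base` but is never referenced itself, so when `unused` is pruned the reference it contributed to `base` has to be subtracted again.

    import org.apache.spark.sql.SparkSession

    object CteRefCountSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()

        // `base` is referenced by both the main query and the unreferenced CTE `unused`.
        spark.sql(
          """WITH base AS (SELECT 1 AS a),
            |     unused AS (SELECT a + 1 AS b FROM base)
            |SELECT a FROM base""".stripMargin).explain(true)

        spark.stop()
      }
    }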
+ collectCTERefs(ref) + }) } - buildCTEMap(child, cteMap, outerCTEId) + buildCTEMap(child, cteMap, collectCTERefs) case ref: CTERelationRef => val (cteDef, refCount, refMap) = cteMap(ref.cteId) cteMap(ref.cteId) = (cteDef, refCount + 1, refMap) - outerCTEId.foreach { cteId => - val (_, _, outerRefMap) = cteMap(cteId) - outerRefMap(ref.cteId) += 1 - } + collectCTERefs(ref) case _ => if (plan.containsPattern(CTE)) { plan.children.foreach { child => - buildCTEMap(child, cteMap, outerCTEId) + buildCTEMap(child, cteMap, collectCTERefs) } plan.expressions.foreach { expr => if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) { expr.foreach { - case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, outerCTEId) + case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, collectCTERefs) case _ => } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 6184160829ba6..ff0bc5e66d755 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -381,7 +381,8 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] GetStructField( ScalarSubquery( - CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output, + subqueryCTE.isStreaming), exprId = ssr.exprId), ssr.headerIndex) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 5d4fcf772b8fc..778813e4e9c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -217,6 +217,11 @@ object NestedColumnAliasing { case _ => false } + private def canAlias(ev: Expression): Boolean = { + // we can not alias the attr from lambda variable whose expr id is not available + !ev.exists(_.isInstanceOf[NamedLambdaVariable]) && ev.references.size == 1 + } + /** * Returns two types of expressions: * - Root references that are individually accessed @@ -225,11 +230,11 @@ object NestedColumnAliasing { */ private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match { case _: AttributeReference => Seq(e) - case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e) + case GetStructField(_: ExtractValue | _: AttributeReference, _, _) if canAlias(e) => Seq(e) case GetArrayStructFields(_: MapValues | _: MapKeys | _: ExtractValue | - _: AttributeReference, _, _, _, _) => Seq(e) + _: AttributeReference, _, _, _, _) if canAlias(e) => Seq(e) case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue) case _ => Seq.empty } @@ -248,13 +253,8 @@ object NestedColumnAliasing { val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]() exprList.foreach { e => extractor(e).foreach { - // we can not alias the attr from lambda variable whose expr id is not available - case ev: ExtractValue if !ev.exists(_.isInstanceOf[NamedLambdaVariable]) => - if (ev.references.size == 1) { - nestedFieldReferences.append(ev) - } + case ev: ExtractValue => 
nestedFieldReferences.append(ev) case ar: AttributeReference => otherRootReferences.append(ar) - case _ => // ignore } } val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala index 4347137bf68b8..04cc230f99b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala @@ -112,9 +112,10 @@ object OptimizeCsvJsonExprs extends Rule[LogicalPlan] { val prunedSchema = StructType(Array(schema(ordinal))) g.copy(child = j.copy(schema = prunedSchema), ordinal = 0) - case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _) - if schema.elementType.asInstanceOf[StructType].length > 1 && j.options.isEmpty => - val prunedSchema = ArrayType(StructType(Array(g.field)), g.containsNull) + case g @ GetArrayStructFields(j @ JsonToStructs(ArrayType(schema: StructType, _), + _, _, _), _, ordinal, _, _) if schema.length > 1 && j.options.isEmpty => + // Obtain the pruned schema by picking the `ordinal` field of the struct. + val prunedSchema = ArrayType(StructType(Array(schema(ordinal))), g.containsNull) g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala index 83646611578cb..61c08eb8f8b6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.internal.SQLConf /** * The rule is applied both normal and AQE Optimizer. It optimizes plan using max rows: @@ -31,19 +32,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ * it's grouping only(include the rewritten distinct plan), convert aggregate to project * - if the max rows of the child of aggregate is less than or equal to 1, * set distinct to false in all aggregate expression + * + * Note: the rule should not be applied to streaming source, since the number of rows it sees is + * just for current microbatch. It does not mean the streaming source will ever produce max 1 + * rows during lifetime of the query. Suppose the case: the streaming query has a case where + * batch 0 runs with empty data in streaming source A which triggers the rule with Aggregate, + * and batch 1 runs with several data in streaming source A which no longer trigger the rule. + * In the above scenario, this could fail the query as stateful operator is expected to be planned + * for every batches whereas here it is planned "selectively". 
*/ object OptimizeOneRowPlan extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { + val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED) + plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) { - case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child - case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child - case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) => + case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) && + isChildEligible(child, enableForStreaming) => child + case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) && + isChildEligible(child, enableForStreaming) => child + case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) && + isChildEligible(child, enableForStreaming) => Project(agg.aggregateExpressions, child) - case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) => + case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) && + isChildEligible(agg.child, enableForStreaming) => agg.transformExpressions { case aggExpr: AggregateExpression if aggExpr.isDistinct => aggExpr.copy(isDistinct = false) } } } + + private def isChildEligible(child: LogicalPlan, enableForStreaming: Boolean): Boolean = { + enableForStreaming || !child.isStreaming + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bb2a86556c031..9e1d264be98f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression.hasCorrelatedSubquery import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{RepartitionOperation, _} @@ -287,7 +288,9 @@ abstract class Optimizer(catalogManager: CatalogManager) ComputeCurrentTime, ReplaceCurrentLike(catalogManager), SpecialDatetimeValues, - RewriteAsOfJoin) + RewriteAsOfJoin, + EvalInlineTables + ) override def apply(plan: LogicalPlan): LogicalPlan = { rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } @@ -574,10 +577,20 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { } case _ => + val subQueryAttributes = if (conf.getConf(SQLConf + .EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES)) { + // Collect the references for all the subquery expressions in the plan. + AttributeSet.fromAttributeSets(plan.expressions.collect { + case e: SubqueryExpression => e.references + }) + } else { + AttributeSet.empty + } + // Remove redundant aliases in the subtree(s). 
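Since the finish-analysis batch above now ends with `EvalInlineTables`, here is a usage-level sketch of what that covers, assuming a local SparkSession (not part of the patch): the `VALUES` clause is checked into a `ResolvedInlineTable` during analysis and then evaluated eagerly, so the optimized plan is expected to contain a `LocalRelation`.

    import org.apache.spark.sql.SparkSession

    object InlineTableEvalSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()

        // Inspect the plans of a query over an inline table.
        spark.sql("SELECT a, b FROM VALUES (1, 'x'), (2, 'y') AS t(a, b)").explain(true)

        spark.stop()
      }
    }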
val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)] val newNode = plan.mapChildren { child => - val newChild = removeRedundantAliases(child, excluded) + val newChild = removeRedundantAliases(child, excluded ++ subQueryAttributes) currentNextAttrPairs ++= createAttributeMapping(child, newChild) newChild } @@ -1152,11 +1165,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { * in aggregate if they are also part of the grouping expressions. Otherwise the plan * after subquery rewrite will not be valid. */ - private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { - p.projectList.forall(_.collect { - case s: ScalarSubquery if s.outerAttrs.nonEmpty => s - }.isEmpty) - } + private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = + !p.projectList.exists(hasCorrelatedSubquery) def buildCleanedProjectList( upper: Seq[NamedExpression], @@ -1649,16 +1659,19 @@ object EliminateSorts extends Rule[LogicalPlan] { * 3) by eliminating the always-true conditions given the constraints on the child's output. */ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { + private def shouldApply(child: LogicalPlan): Boolean = + SQLConf.get.getConf(SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN) || !child.isStreaming + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(FILTER), ruleId) { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child // If the filter condition always evaluate to null or false, // replace the input with an empty relation. - case Filter(Literal(null, _), child) => - LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming) - case Filter(Literal(false, BooleanType), child) => - LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming) + case Filter(Literal(null, _), child) if shouldApply(child) => + LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) + case Filter(Literal(false, BooleanType), child) if shouldApply(child) => + LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) // If any deterministic condition is guaranteed to be true given the constraints on the child's // output, remove the condition case f @ Filter(fc, p: LogicalPlan) => @@ -2183,11 +2196,15 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { case d @ Deduplicate(keys, child) if !child.isStreaming => val keyExprIds = keys.map(_.exprId) + val generatedAliasesMap = new mutable.HashMap[Attribute, Alias](); val aggCols = child.output.map { attr => if (keyExprIds.contains(attr.exprId)) { attr } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() + // Keep track of the generated aliases to avoid generating multiple aliases + // for the same attribute (in case the attribute is duplicated) + generatedAliasesMap.getOrElseUpdate(attr, + Alias(new First(attr).toAggregateExpression(), attr.name)()) } } // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index fd7a87087ddd2..738d547d4fb61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -65,6 +65,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = true + protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Union if p.children.exists(isEmpty) => val newChildren = p.children.filterNot(isEmpty) @@ -109,20 +111,22 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) case LeftSemi if isRightEmpty | isFalseCondition => empty(p) - case LeftAnti if isRightEmpty | isFalseCondition => p.left + case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) => + p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) - case LeftOuter | FullOuter if isRightEmpty => + case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) case RightOuter if isRightEmpty => empty(p) - case RightOuter | FullOuter if isLeftEmpty => + case RightOuter | FullOuter if isLeftEmpty && canExecuteWithoutJoin(p.right) => Project(nullValueProjectList(p.left) ++ p.right.output, p.right) - case LeftOuter if isFalseCondition => + case LeftOuter if isFalseCondition && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) - case RightOuter if isFalseCondition => + case RightOuter if isFalseCondition && canExecuteWithoutJoin(p.right) => Project(nullValueProjectList(p.left) ++ p.right.output, p.right) case _ => p } - } else if (joinType == LeftSemi && conditionOpt.isEmpty && nonEmpty(p.right)) { + } else if (joinType == LeftSemi && conditionOpt.isEmpty && + nonEmpty(p.right) && canExecuteWithoutJoin(p.left)) { p.left } else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) { empty(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala index e643a1af363a1..aa13e6a67c510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala @@ -141,7 +141,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] { cteDef } - case cteRef @ CTERelationRef(cteId, _, output, _) => + case cteRef @ CTERelationRef(cteId, _, output, _, _) => val (cteDef, _, _, newAttrSet) = cteMap(cteId) if (needsPruning(cteDef.child, newAttrSet)) { val indices = newAttrSet.toSeq.map(cteDef.output.indexOf) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index da3cf782f6682..5aef82b64ed32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -197,6 +197,17 @@ import org.apache.spark.util.collection.Utils * techniques. */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { + private def mustRewrite( + distinctAggs: Seq[AggregateExpression], + groupingExpressions: Seq[Expression]): Boolean = { + // If there are any distinct AggregateExpressions with filter, we need to rewrite the query. + // Also, if there are no grouping expressions and all distinct aggregate expressions are + // foldable, we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). Without this case, + // non-grouping aggregation queries with distinct aggregate expressions will be incorrectly + // handled by the aggregation strategy, causing wrong results when working with empty tables. + distinctAggs.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && + distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable))) + } private def mayNeedtoRewrite(a: Aggregate): Boolean = { val aggExpressions = collectAggregateExprs(a) @@ -204,8 +215,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // We need at least two distinct aggregates or the single distinct aggregate group exists filter // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. - // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) + distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( @@ -236,7 +246,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { + if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -390,13 +400,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { (distinctAggOperatorMap.flatMap(_._2) ++ regularAggOperatorMap.map(e => (e._1, e._3))).toMap + val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable) val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case e: Expression => // The same GROUP BY clauses can have different forms (different names for instance) in // the groupBy and aggregate expressions of an aggregate. This makes a map lookup // tricky. So we do a linear search for a semantically equal group by expression. - groupByMap + groupByMapNonFoldable .find(ge => e.semanticEquals(ge._1)) .map(_._2) .getOrElse(transformations.getOrElse(e, e)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index d4f0f72c9352b..c0995b273bd03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -78,7 +78,7 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. 
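To make the intent of the new `mustRewrite` check concrete, here is a minimal, hedged sketch (the object name and local SparkSession setup are illustrative, not part of this patch) of the query shape it targets: a non-grouping aggregate whose distinct children are all foldable, which must now go through the rewrite so that empty input yields the correct count.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object DistinctOnEmptyInputExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("distinct-foldable").getOrCreate()
    import spark.implicits._

    Seq.empty[Int].toDF("a").createOrReplaceTempView("empty_tab")

    // No grouping expressions and a foldable distinct child: with the new check this
    // query is rewritten, so COUNT(DISTINCT 1) over zero rows returns 0 as expected.
    spark.sql("SELECT COUNT(DISTINCT 1) FROM empty_tab").show()

    spark.stop()
  }
}
```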
case e if e.foldable => try { - Literal.create(e.eval(EmptyRow), e.dataType) + Literal.create(e.freshCopyIfContainsStatefulExpression().eval(EmptyRow), e.dataType) } catch { case NonFatal(_) if isConditionalBranch => // When doing constant folding inside conditional expressions, we should not fail @@ -934,7 +934,14 @@ object FoldablePropagation extends Rule[LogicalPlan] { val newFoldableMap = collectFoldables(newProject.projectList) (newProject, newFoldableMap) - case a: Aggregate => + // FoldablePropagation rule can produce incorrect optimized plan for streaming queries. + // This is because the optimizer can replace the grouping expressions, or join column + // with a literal value if the grouping key is constant for the micro-batch. However, + // as Streaming queries also read from the StateStore, this optimization also + // overwrites any keys read from State Store. We need to disable this optimization + // until we can make optimizer aware of Streaming state store. The State Store nodes + // are currently added in the Physical plan. + case a: Aggregate if !a.isStreaming => val (newChild, foldableMap) = propagateFoldables(a.child) val newAggregate = replaceFoldable(a.withNewChildren(Seq(newChild)).asInstanceOf[Aggregate], foldableMap) @@ -971,7 +978,14 @@ object FoldablePropagation extends Rule[LogicalPlan] { // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j: Join => + // FoldablePropagation rule can produce incorrect optimized plan for streaming queries. + // This is because the optimizer can replace the grouping expressions, or join column + // with a literal value if the grouping key is constant for the micro-batch. However, + // as Streaming queries also read from the StateStore, this optimization also + // overwrites any keys read from State Store. We need to disable this optimization + // until we can make optimizer aware of Streaming state store. The State Store nodes + // are currently added in the Physical plan. 
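The streaming guards added here (above for `Aggregate`, immediately below for `Join`) keep FoldablePropagation away from stateful operators. As a rough batch-only illustration of what the rule normally does — assuming a local SparkSession; the object name is made up — a constant-valued column can be folded into the aggregate, which is exactly what must not happen to keys that are also read back from the state store:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

// Illustrative sketch only; not part of the patch.
object FoldablePropagationBatchExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("foldable").getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3).toDF("v")
      .withColumn("k", lit("constant-key")) // foldable column
      .groupBy("k")                         // batch: safe to fold "k" into the plan
      .count()

    // The optimized plan may show the literal propagated into the aggregate.
    df.explain(true)
    spark.stop()
  }
}
```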
+ case j: Join if !j.left.isStreaming || !j.right.isStreaming => val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip val foldableMap = AttributeMap( foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 466781fa1def7..d7efc16a514bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -19,7 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import java.time.{Instant, LocalDateTime} +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{CastSupport, ResolvedInlineTable} +import org.apache.spark.sql.catalyst.analysis.ResolveInlineTables.prepareForEval import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -27,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -71,6 +77,33 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { } } +/** + * Computes expressions in inline tables. This rule is supposed to be called at the very end + * of the analysis phase, given that all the expressions need to be fully resolved/replaced + * at this point. + */ +object EvalInlineTables extends Rule[LogicalPlan] with CastSupport { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformDownWithSubqueriesAndPruning(_.containsPattern(INLINE_TABLE_EVAL)) { + case table: ResolvedInlineTable => + val newRows: Seq[InternalRow] = + table.rows.map { row => InternalRow.fromSeq(row.map { e => + try { + prepareForEval(e).eval() + } catch { + case NonFatal(ex) => + table.failAnalysis( + errorClass = "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION", + messageParameters = Map("sqlExpr" -> toSQLExpr(e)), + cause = ex) + }}) + } + + LocalRelation(table.output, newRows) + } + } +} + /** * Computes the current date and time to make sure we return the same result in a single query. 
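A hedged usage sketch of the `EvalInlineTables` rule added above (the local SparkSession and object name are assumptions): expressions in a VALUES list are evaluated once analysis finishes, producing a `LocalRelation`, and an expression that fails during that evaluation is reported via `INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION`.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object InlineTableEvalExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("inline-tables").getOrCreate()

    // The `1 + 1` below is computed while the inline table is turned into a LocalRelation;
    // an expression that throws at this point would surface the new error class instead.
    spark.sql("SELECT * FROM VALUES (1, 1 + 1) AS t(a, b)").show()

    spark.stop()
  }
}
```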
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 91cd838ad617a..ee20053157816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -118,16 +118,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { withSubquery.foldLeft(newFilter) { case (p, Exists(sub, _, _, conditions, subHint)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - buildJoin(outerPlan, sub, LeftSemi, joinCond, subHint) + val join = buildJoin(outerPlan, sub, LeftSemi, joinCond, subHint) + Project(p.output, join) case (p, Not(Exists(sub, _, _, conditions, subHint))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - buildJoin(outerPlan, sub, LeftAnti, joinCond, subHint) + val join = buildJoin(outerPlan, sub, LeftAnti, joinCond, subHint) + Project(p.output, join) case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) => // Deduplicate conflicting attributes if any. val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) - Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint(None, subHint)) + val join = Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint(None, subHint)) + Project(p.output, join) case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 83938632e534f..2b600743e1bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} @@ -66,12 +66,25 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { builder: Seq[String] => LogicalPlan): LogicalPlan = { val exprCtx = ctx.expression if (exprCtx != null) { - 
PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, builder) + PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, Nil, + (ident, _) => builder(ident)) } else { builder.apply(visitMultipartIdentifier(ctx.multipartIdentifier)) } } + protected def withIdentClause( + ctx: IdentifierReferenceContext, + otherPlans: Seq[LogicalPlan], + builder: (Seq[String], Seq[LogicalPlan]) => LogicalPlan): LogicalPlan = { + val exprCtx = ctx.expression + if (exprCtx != null) { + PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, otherPlans, builder) + } else { + builder.apply(visitMultipartIdentifier(ctx.multipartIdentifier), otherPlans) + } + } + protected def withIdentClause( ctx: IdentifierReferenceContext, builder: Seq[String] => Expression): Expression = { @@ -85,12 +98,13 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { protected def withFuncIdentClause( ctx: FunctionNameContext, - builder: Seq[String] => LogicalPlan): LogicalPlan = { + otherPlans: Seq[LogicalPlan], + builder: (Seq[String], Seq[LogicalPlan]) => LogicalPlan): LogicalPlan = { val exprCtx = ctx.expression if (exprCtx != null) { - PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, builder) + PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, otherPlans, builder) } else { - builder.apply(getFunctionMultiparts(ctx)) + builder.apply(getFunctionMultiparts(ctx), otherPlans) } } @@ -320,12 +334,12 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case table: InsertIntoTableContext => val (relationCtx, cols, partition, ifPartitionNotExists, byName) = visitInsertIntoTable(table) - withIdentClause(relationCtx, ident => { + withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident), + createUnresolvedRelation(relationCtx, ident, Seq(TableWritePrivilege.INSERT)), partition, cols, - query, + otherPlans.head, overwrite = false, ifPartitionNotExists, byName) @@ -333,21 +347,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case table: InsertOverwriteTableContext => val (relationCtx, cols, partition, ifPartitionNotExists, byName) = visitInsertOverwriteTable(table) - withIdentClause(relationCtx, ident => { + withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident), + createUnresolvedRelation(relationCtx, ident, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), partition, cols, - query, + otherPlans.head, overwrite = true, ifPartitionNotExists, byName) }) case ctx: InsertIntoReplaceWhereContext => - withIdentClause(ctx.identifierReference, ident => { + withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => { OverwriteByExpression.byPosition( - createUnresolvedRelation(ctx.identifierReference, ident), - query, + createUnresolvedRelation(ctx.identifierReference, ident, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), + otherPlans.head, expression(ctx.whereClause().booleanExpression())) }) case dir: InsertOverwriteDirContext => @@ -425,7 +441,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeleteFromTable( ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation( + ctx.identifierReference, 
writePrivileges = Seq(TableWritePrivilege.DELETE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val predicate = if (ctx.whereClause() != null) { @@ -437,7 +454,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation( + ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val assignments = withAssignments(ctx.setClause().assignmentList()) @@ -459,10 +477,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { - val targetTable = createUnresolvedRelation(ctx.target) - val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") - val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) - val sourceTableOrQuery = if (ctx.source != null) { createUnresolvedRelation(ctx.source) } else if (ctx.sourceQuery != null) { @@ -492,7 +506,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.matchedAction().getText}") } } - } + }.toSeq val notMatchedActions = ctx.notMatchedClause().asScala.map { clause => { if (clause.notMatchedAction().INSERT() != null) { @@ -513,7 +527,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.notMatchedAction().getText}") } } - } + }.toSeq val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map { clause => { val notMatchedBySourceAction = clause.notMatchedBySourceAction() @@ -528,7 +542,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}") } } - } + }.toSeq if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx) } @@ -547,13 +561,19 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx) } + val targetTable = createUnresolvedRelation( + ctx.target, + writePrivileges = MergeIntoTable.getWritePrivileges( + matchedActions, notMatchedActions, notMatchedBySourceActions)) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) MergeIntoTable( aliasedTarget, aliasedSource, mergeCondition, - matchedActions.toSeq, - notMatchedActions.toSeq, - notMatchedBySourceActions.toSeq) + matchedActions, + notMatchedActions, + notMatchedBySourceActions) } /** @@ -787,7 +807,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { // Create the attributes. val (attributes, schemaLess) = if (transformClause.colTypeList != null) { // Typed return columns. 
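The parser changes above attach write privileges to the target relation of each DML statement. The standalone sketch below (its types are simplified stand-ins, not the connector API) mirrors that mapping: DELETE requires DELETE, UPDATE requires UPDATE, INSERT requires INSERT (plus DELETE for overwrite/replace-where), and MERGE requires the union of privileges implied by its WHEN clauses.

```scala
// Standalone illustration; Privilege and MergeAction here are simplified stand-ins.
object WritePrivilegeMappingSketch {
  sealed trait Privilege
  case object Insert extends Privilege
  case object Update extends Privilege
  case object Delete extends Privilege

  sealed trait MergeAction
  case object DeleteAction extends MergeAction
  case object UpdateAction extends MergeAction
  case object InsertAction extends MergeAction

  // Mirrors the idea of MergeIntoTable.getWritePrivileges: collect the distinct
  // privileges implied by the WHEN [NOT] MATCHED clauses of a MERGE statement.
  def privilegesForMerge(actions: Seq[MergeAction]): Set[Privilege] =
    actions.map {
      case DeleteAction => Delete
      case UpdateAction => Update
      case InsertAction => Insert
    }.toSet

  def main(args: Array[String]): Unit = {
    // WHEN MATCHED THEN UPDATE ... WHEN NOT MATCHED THEN INSERT ...
    println(privilegesForMerge(Seq(UpdateAction, InsertAction))) // contains Update and Insert
  }
}
```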
- (DataTypeUtils.toAttributes(createSchema(transformClause.colTypeList)), false) + val schema = createSchema(transformClause.colTypeList) + val replacedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) + (DataTypeUtils.toAttributes(replacedSchema), false) } else if (transformClause.identifierSeq != null) { // Untyped return columns. val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => @@ -1291,7 +1313,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * Create an Unpivot column. */ override def visitUnpivotColumn(ctx: UnpivotColumnContext): NamedExpression = withOrigin(ctx) { - UnresolvedAlias(UnresolvedAttribute(visitMultipartIdentifier(ctx.multipartIdentifier))) + UnresolvedAttribute(visitMultipartIdentifier(ctx.multipartIdentifier)) } /** @@ -1596,7 +1618,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { withFuncIdentClause( func.functionName, - ident => { + Nil, + (ident, _) => { if (ident.length > 1) { throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx) } @@ -2223,13 +2246,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } } - /** - * Create an expression for the IDENTIFIER() clause. - */ - override def visitIdentifierClause(ctx: IdentifierClauseContext): Expression = withOrigin(ctx) { - ExpressionWithUnresolvedIdentifier(expression(ctx.expression), UnresolvedAttribute(_)) - } - /** * Create a (windowed) Function expression. */ @@ -2251,19 +2267,31 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val filter = Option(ctx.where).map(expression(_)) val ignoreNulls = Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false) - val funcCtx = ctx.functionName - val func = withFuncIdentClause( - funcCtx, - ident => UnresolvedFunction(ident, arguments, isDistinct, filter, ignoreNulls) - ) - // Check if the function is evaluated in a windowed context. - ctx.windowSpec match { - case spec: WindowRefContext => - UnresolvedWindowExpression(func, visitWindowRef(spec)) - case spec: WindowDefContext => - WindowExpression(func, visitWindowDef(spec)) - case _ => func + // Is this an IDENTIFIER clause instead of a function call? + if (ctx.functionName.identFunc != null && + arguments.length == 1 && // One argument + ctx.setQuantifier == null && // No other clause + ctx.where == null && + ctx.nullsOption == null && + ctx.windowSpec == null) { + ExpressionWithUnresolvedIdentifier(arguments.head, UnresolvedAttribute(_)) + } else { + // It's a function call + val funcCtx = ctx.functionName + val func = withFuncIdentClause( + funcCtx, + ident => UnresolvedFunction(ident, arguments, isDistinct, filter, ignoreNulls) + ) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(func, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(func, visitWindowDef(spec)) + case _ => func + } } } @@ -2771,16 +2799,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * Create an [[UnresolvedRelation]] from an identifier reference. 
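Since the IDENTIFIER() handling above now flows through the function-call path (a single argument with no DISTINCT/FILTER/OVER clause is treated as an identifier reference rather than a function), a quick usage sketch may help; it assumes a local SparkSession and a Spark build that includes this change.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object IdentifierClauseExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("identifier-clause").getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b")).toDF("id", "name").createOrReplaceTempView("people")

    // IDENTIFIER('name') is resolved to the column `name` at analysis time,
    // not dispatched as a regular function call.
    spark.sql("SELECT IDENTIFIER('name') FROM people").show()

    spark.stop()
  }
}
```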
*/ private def createUnresolvedRelation( - ctx: IdentifierReferenceContext): LogicalPlan = withOrigin(ctx) { - withIdentClause(ctx, UnresolvedRelation(_)) + ctx: IdentifierReferenceContext, + writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) { + withIdentClause(ctx, parts => { + val relation = new UnresolvedRelation(parts, isStreaming = false) + relation.requireWritePrivileges(writePrivileges) + }) } /** * Create an [[UnresolvedRelation]] from a multi-part identifier. */ private def createUnresolvedRelation( - ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) { - UnresolvedRelation(ident) + ctx: ParserRuleContext, + ident: Seq[String], + writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) { + val relation = new UnresolvedRelation(ident, isStreaming = false) + relation.requireWritePrivileges(writePrivileges) } /** @@ -3269,7 +3304,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { ctx: ExpressionPropertyListContext): OptionList = { val options = ctx.expressionProperty.asScala.map { property => val key: String = visitPropertyKey(property.key) - val value: Expression = Option(property.value).map(expression).orNull + val value: Expression = Option(property.value).map(expression).getOrElse { + operationNotAllowed(s"A value must be specified for the key: $key.", ctx) + } key -> value }.toSeq OptionList(options) @@ -4577,7 +4614,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { if (query.isDefined) { CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) } else { - CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options) + CacheTable( + createUnresolvedRelation(ctx.identifierReference, ident, writePrivileges = Nil), + ident, isLazy, options) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index aee4790eb42aa..12ee0274fd7a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import java.util.IdentityHashMap + import scala.collection.mutable import org.apache.spark.sql.AnalysisException @@ -429,7 +431,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] override def verboseString(maxFields: Int): String = simpleString(maxFields) override def simpleStringWithNodeId(): String = { - val operatorId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + val operatorId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id") + .getOrElse("unknown") s"$nodeName ($operatorId)".trim } @@ -449,7 +452,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } protected def formattedNodeName: String = { - val opId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + val opId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id") + .getOrElse("unknown") val codegenId = getTagValue(QueryPlan.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") s"($opId) $nodeName$codegenId" @@ -626,9 +630,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } object QueryPlan extends PredicateHelper { - val OP_ID_TAG = TreeNodeTag[Int]("operatorId") val CODEGEN_ID_TAG = new 
TreeNodeTag[Int]("wholeStageCodegenId") + /** + * A thread local map to store the mapping between the query plan and the query plan id. + * The scope of this thread local is within ExplainUtils.processPlan. The reason we define it here + * is because [[ QueryPlan ]] also needs this, and it doesn't have access to `execution` package + * from `catalyst`. + */ + val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = ThreadLocal.withInitial(() => + new IdentityHashMap[QueryPlan[_], Int]()) + /** * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 1088655f60cd4..a901fa5a72c5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -69,6 +69,8 @@ trait FunctionBuilderBase[T] { } def build(funcName: String, expressions: Seq[Expression]): T + + def supportsLambda: Boolean = false } object NamedParametersSupport { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 374eb070db1c9..7fe8bd356ea94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -116,7 +116,9 @@ abstract class LogicalPlan def resolve(schema: StructType, resolver: Resolver): Seq[Attribute] = { schema.map { field => resolve(field.name :: Nil, resolver).map { - case a: AttributeReference => a + case a: AttributeReference => + // Keep the metadata in given schema. + a.withMetadata(field.metadata) case _ => throw QueryExecutionErrors.resolveCannotHandleNestedSchema(this) }.getOrElse { throw QueryCompilationErrors.cannotResolveAttributeError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4bb830662a33f..f76e698a64005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -235,6 +235,20 @@ object Project { } } +case class DataFrameDropColumns(dropList: Seq[Expression], child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = Nil + + override def maxRows: Option[Long] = child.maxRows + override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition + + final override val nodePatterns: Seq[TreePattern] = Seq(DF_DROP_COLUMNS) + + override lazy val resolved: Boolean = false + + override protected def withNewChildInternal(newChild: LogicalPlan): DataFrameDropColumns = + copy(child = newChild) +} + /** * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. 
This operation is similar to a `flatMap` in functional @@ -839,6 +853,7 @@ case class CTERelationRef( cteId: Long, _resolved: Boolean, override val output: Seq[Attribute], + override val isStreaming: Boolean, statsOpt: Option[Statistics] = None) extends LeafNode with MultiInstanceRelation { final override val nodePatterns: Seq[TreePattern] = Seq(CTE) @@ -1048,10 +1063,12 @@ case class Range( if (numElements == 0) { Statistics(sizeInBytes = 0, rowCount = Some(0)) } else { - val (minVal, maxVal) = if (step > 0) { - (start, start + (numElements - 1) * step) + val (minVal, maxVal) = if (!numElements.isValidLong) { + (None, None) + } else if (step > 0) { + (Some(start), Some(start + (numElements.toLong - 1) * step)) } else { - (start + (numElements - 1) * step, start) + (Some(start + (numElements.toLong - 1) * step), Some(start)) } val histogram = if (conf.histogramEnabled) { @@ -1062,8 +1079,8 @@ case class Range( val colStat = ColumnStat( distinctCount = Some(numElements), - max = Some(maxVal), - min = Some(minVal), + max = maxVal, + min = minVal, nullCount = Some(0), avgLen = Some(LongType.defaultSize), maxLen = Some(LongType.defaultSize), @@ -1940,6 +1957,16 @@ case class DeduplicateWithinWatermark(keys: Seq[Attribute], child: LogicalPlan) */ trait SupportsSubquery extends LogicalPlan +/** + * Trait that logical plans can extend to check whether it can allow non-deterministic + * expressions and pass the CheckAnalysis rule. + */ +trait SupportsNonDeterministicExpression extends LogicalPlan { + + /** Returns whether it allows non-deterministic expressions. */ + def allowNonDeterministicExpression: Boolean +} + /** * Collect arbitrary (named) metrics from a dataset. As soon as the query reaches a completion * point (batch query completes or streaming query epoch completes) an event is emitted on the @@ -1952,7 +1979,8 @@ trait SupportsSubquery extends LogicalPlan case class CollectMetrics( name: String, metrics: Seq[NamedExpression], - child: LogicalPlan) + child: LogicalPlan, + dataframeId: Long) extends UnaryNode { override lazy val resolved: Boolean = { @@ -1999,6 +2027,8 @@ case class LateralJoin( joinType: JoinType, condition: Option[Expression]) extends UnaryNode { + override lazy val allAttributes: AttributeSeq = left.output ++ right.plan.output + require(Seq(Inner, LeftOuter, Cross).contains(joinType), s"Unsupported lateral join type $joinType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 739ffa487e393..d7669ac0b1d78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -62,10 +62,11 @@ trait V2WriteCommand extends UnaryCommand with KeepAnalyzedQuery { table.skipSchemaResolution || (query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) // names and types must match, nullability must be compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && (outAttr.nullable || !inAttr.nullable) }) } @@ -757,6 +758,21 @@ case class MergeIntoTable( 
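The `Range` statistics change above guards against an element count that does not fit in a `Long`. A minimal sketch of the scenario (the local SparkSession and object name are assumptions; only the plan is inspected, nothing is executed):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object RangeStatsOverflowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("range-stats").getOrCreate()

    // numElements is roughly 2^64 - 1 here, which is not a valid Long, so with this
    // change the min/max column statistics are simply left undefined instead of
    // being derived from overflowing Long arithmetic.
    val hugeRange = spark.range(Long.MinValue, Long.MaxValue, 1)
    hugeRange.explain("cost")

    spark.stop()
  }
}
```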
copy(targetTable = newLeft, sourceTable = newRight) } +object MergeIntoTable { + def getWritePrivileges( + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = { + val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege] + (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach { + case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE) + case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE) + case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT) + } + privileges.toSeq + } +} + sealed abstract class MergeAction extends Expression with Unevaluable { def condition: Option[Expression] override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d2f9e9b5d5bf5..211b5a05eb70c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -258,18 +258,8 @@ case object SinglePartition extends Partitioning { SinglePartitionShuffleSpec } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. - * - * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires - * stateful operators to retain the same physical partitioning during the lifetime of the query - * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged - * across Spark versions. Violation of this requirement may bring silent correctness issue. - */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { +trait HashPartitioningLike extends Expression with Partitioning with Unevaluable { + def expressions: Seq[Expression] override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -294,6 +284,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } } } +} + +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. + * + * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires + * stateful operators to retain the same physical partitioning during the lifetime of the query + * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged + * across Spark versions. Violation of this requirement may bring silent correctness issue. 
+ */ +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends HashPartitioningLike { override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = HashShuffleSpec(this, distribution) @@ -308,6 +312,27 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } +case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) + +/** + * Represents a partitioning where partitions have been coalesced from a HashPartitioning into a + * fewer number of partitions. + */ +case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary]) + extends HashPartitioningLike { + + override def expressions: Seq[Expression] = from.expressions + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) + + override val numPartitions: Int = partitions.length + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = + copy(from = from.copy(expressions = newChildren)) +} + /** * Represents a partitioning where rows are split across partitions based on transforms defined * by `expressions`. `partitionValuesOpt`, if defined, should contain value of partition key(s) in @@ -661,6 +686,26 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +case class CoalescedHashShuffleSpec( + from: ShuffleSpec, + partitions: Seq[CoalescedBoundary]) extends ShuffleSpec { + + override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { + case SinglePartitionShuffleSpec => + numPartitions == 1 + case CoalescedHashShuffleSpec(otherParent, otherPartitions) => + partitions == otherPartitions && from.isCompatibleWith(otherParent) + case ShuffleSpecCollection(specs) => + specs.exists(isCompatibleWith) + case _ => + false + } + + override def canCreatePartitioning: Boolean = false + + override def numPartitions: Int = partitions.length +} + case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, distribution: ClusteredDistribution) extends ShuffleSpec { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 9d29ca1f9c6e1..c16b50a2b17a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -143,7 +143,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { override val maxIterationsSetting: String = null) extends Strategy /** A batch of rules. */ - protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) + protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. 
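`CoalescedHashPartitioning` and `CoalescedHashShuffleSpec` above describe a hash-partitioned shuffle whose reducer ranges have been coalesced into fewer partitions. A loose, hedged way to observe that scenario end to end (standard AQE configs, a local SparkSession, and a made-up object name):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object CoalescedShuffleExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("coalesced-shuffle")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.shuffle.partitions", "200")
      .getOrCreate()
    import spark.implicits._

    val grouped = spark.range(0, 10000).withColumn("k", $"id" % 10).groupBy("k").count()
    grouped.collect()  // run the query so AQE can finalize the coalesced shuffle read
    grouped.explain()  // the final plan may show an AQEShuffleRead over coalesced partitions

    spark.stop()
  }
}
```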
*/ protected def batches: Seq[Batch] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index caf679f3e7a7a..96f78d251c39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -166,6 +166,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" :: "org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps" :: "org.apache.spark.sql.catalyst.optimizer.TransposeWindow" :: + "org.apache.spark.sql.catalyst.optimizer.EvalInlineTables" :: "org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9e605a45414be..82228a5b2aafd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -1030,10 +1030,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] append(str) append("\n") - if (innerChildren.nonEmpty) { + val innerChildrenLocal = innerChildren + if (innerChildrenLocal.nonEmpty) { lastChildren.add(children.isEmpty) lastChildren.add(false) - innerChildren.init.foreach(_.generateTreeString( + innerChildrenLocal.init.foreach(_.generateTreeString( depth + 2, lastChildren, append, verbose, addSuffix = addSuffix, maxFields = maxFields, printNodeId = printNodeId, indent = indent)) lastChildren.remove(lastChildren.size() - 1) @@ -1041,7 +1042,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] lastChildren.add(children.isEmpty) lastChildren.add(true) - innerChildren.last.generateTreeString( + innerChildrenLocal.last.generateTreeString( depth + 2, lastChildren, append, verbose, addSuffix = addSuffix, maxFields = maxFields, printNodeId = printNodeId, indent = indent) lastChildren.remove(lastChildren.size() - 1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index b806ebbed52d0..ce8f5951839e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -53,6 +53,7 @@ object TreePattern extends Enumeration { val IF: Value = Value val IN: Value = Value val IN_SUBQUERY: Value = Value + val INLINE_TABLE_EVAL: Value = Value val INSET: Value = Value val INTERSECT: Value = Value val INVOKE: Value = Value @@ -105,6 +106,7 @@ object TreePattern extends Enumeration { val AS_OF_JOIN: Value = Value val COMMAND: Value = Value val CTE: Value = Value + val DF_DROP_COLUMNS: Value = Value val DISTINCT_LIKE: Value = Value val EVAL_PYTHON_UDF: Value = Value val EVAL_PYTHON_UDTF: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index d392557e650e3..0f60f8d7b1ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -239,8 +239,7 
@@ case class PhysicalMapType(keyType: DataType, valueType: DataType, valueContains class PhysicalNullType() extends PhysicalDataType with PhysicalPrimitiveType { override private[sql] def ordering = - throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( - "PhysicalNullType") + implicitly[Ordering[Unit]].asInstanceOf[Ordering[Any]] override private[sql] type InternalType = Any @transient private[sql] lazy val tag = typeTag[InternalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index b9d83d444909d..87982e7a5f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.types._ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { - private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" + // visible for testing + private[sql] val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" /** * Replaces CharType/VarcharType with StringType recursively in the given struct type. If a @@ -237,14 +238,14 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { * attributes. When comparing two char type columns/fields, we need to pad the shorter one to * the longer length. */ - def addPaddingInStringComparison(attrs: Seq[Attribute]): Seq[Expression] = { + def addPaddingInStringComparison(attrs: Seq[Attribute], alwaysPad: Boolean): Seq[Expression] = { val rawTypes = attrs.map(attr => getRawType(attr.metadata)) if (rawTypes.exists(_.isEmpty)) { attrs } else { val typeWithTargetCharLength = rawTypes.map(_.get).reduce(typeWithWiderCharLength) attrs.zip(rawTypes.map(_.get)).map { case (attr, rawType) => - padCharToTargetLength(attr, rawType, typeWithTargetCharLength).getOrElse(attr) + padCharToTargetLength(attr, rawType, typeWithTargetCharLength, alwaysPad).getOrElse(attr) } } } @@ -267,9 +268,10 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { private def padCharToTargetLength( expr: Expression, rawType: DataType, - typeWithTargetCharLength: DataType): Option[Expression] = { + typeWithTargetCharLength: DataType, + alwaysPad: Boolean): Option[Expression] = { (rawType, typeWithTargetCharLength) match { - case (CharType(len), CharType(target)) if target > len => + case (CharType(len), CharType(target)) if alwaysPad || target > len => Some(StringRPad(expr, Literal(target))) case (StructType(fields), StructType(targets)) => @@ -280,7 +282,8 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { while (i < fields.length) { val field = fields(i) val fieldExpr = GetStructField(expr, i, Some(field.name)) - val padded = padCharToTargetLength(fieldExpr, field.dataType, targets(i).dataType) + val padded = padCharToTargetLength( + fieldExpr, field.dataType, targets(i).dataType, alwaysPad) needPadding = padded.isDefined createStructExprs += Literal(field.name) createStructExprs += padded.getOrElse(fieldExpr) @@ -290,7 +293,7 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { case (ArrayType(et, containsNull), ArrayType(target, _)) => val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) - padCharToTargetLength(param, et, target).map { padded => + padCharToTargetLength(param, et, target, alwaysPad).map { padded => val func 
= LambdaFunction(padded, Seq(param)) ArrayTransform(expr, func) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index e051cfc37f12d..4d90007400ea7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -107,25 +107,30 @@ object IntervalUtils extends SparkIntervalUtils { fallBackNotice: Option[String] = None) = { throw new IllegalArgumentException( s"Interval string does not match $intervalStr format of " + - s"${supportedFormat((startFiled, endField)).map(format => s"`$format`").mkString(", ")} " + + s"${supportedFormat((intervalStr, startFiled, endField)) + .map(format => s"`$format`").mkString(", ")} " + s"when cast to $typeName: ${input.toString}" + s"${fallBackNotice.map(s => s", $s").getOrElse("")}") } val supportedFormat = Map( - (YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"), - (YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"), - (YM.MONTH, YM.MONTH) -> Seq("[+|-]m", "INTERVAL [+|-]'[+|-]m' MONTH"), - (DT.DAY, DT.DAY) -> Seq("[+|-]d", "INTERVAL [+|-]'[+|-]d' DAY"), - (DT.DAY, DT.HOUR) -> Seq("[+|-]d h", "INTERVAL [+|-]'[+|-]d h' DAY TO HOUR"), - (DT.DAY, DT.MINUTE) -> Seq("[+|-]d h:m", "INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE"), - (DT.DAY, DT.SECOND) -> Seq("[+|-]d h:m:s.n", "INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND"), - (DT.HOUR, DT.HOUR) -> Seq("[+|-]h", "INTERVAL [+|-]'[+|-]h' HOUR"), - (DT.HOUR, DT.MINUTE) -> Seq("[+|-]h:m", "INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE"), - (DT.HOUR, DT.SECOND) -> Seq("[+|-]h:m:s.n", "INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND"), - (DT.MINUTE, DT.MINUTE) -> Seq("[+|-]m", "INTERVAL [+|-]'[+|-]m' MINUTE"), - (DT.MINUTE, DT.SECOND) -> Seq("[+|-]m:s.n", "INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND"), - (DT.SECOND, DT.SECOND) -> Seq("[+|-]s.n", "INTERVAL [+|-]'[+|-]s.n' SECOND") + ("year-month", YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"), + ("year-month", YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"), + ("year-month", YM.MONTH, YM.MONTH) -> Seq("[+|-]m", "INTERVAL [+|-]'[+|-]m' MONTH"), + ("day-time", DT.DAY, DT.DAY) -> Seq("[+|-]d", "INTERVAL [+|-]'[+|-]d' DAY"), + ("day-time", DT.DAY, DT.HOUR) -> Seq("[+|-]d h", "INTERVAL [+|-]'[+|-]d h' DAY TO HOUR"), + ("day-time", DT.DAY, DT.MINUTE) -> + Seq("[+|-]d h:m", "INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE"), + ("day-time", DT.DAY, DT.SECOND) -> + Seq("[+|-]d h:m:s.n", "INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND"), + ("day-time", DT.HOUR, DT.HOUR) -> Seq("[+|-]h", "INTERVAL [+|-]'[+|-]h' HOUR"), + ("day-time", DT.HOUR, DT.MINUTE) -> Seq("[+|-]h:m", "INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE"), + ("day-time", DT.HOUR, DT.SECOND) -> + Seq("[+|-]h:m:s.n", "INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND"), + ("day-time", DT.MINUTE, DT.MINUTE) -> Seq("[+|-]m", "INTERVAL [+|-]'[+|-]m' MINUTE"), + ("day-time", DT.MINUTE, DT.SECOND) -> + Seq("[+|-]m:s.n", "INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND"), + ("day-time", DT.SECOND, DT.SECOND) -> Seq("[+|-]s.n", "INTERVAL [+|-]'[+|-]s.n' SECOND") ) def castStringToYMInterval( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f926..06d3910311b1c 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -23,6 +23,13 @@ import org.apache.spark.unsafe.types.UTF8String object NumberConverter { + /** + * The output string has a max length of one char per bit in the 64-bit `Long` intermediate + * representation plus one char for the '-' sign. This happens in practice when converting + * `Long.MinValue` with `toBase` equal to -2. + */ + private final val MAX_OUTPUT_LENGTH = java.lang.Long.SIZE + 1 + /** * Decode v into value[]. * @@ -148,7 +155,7 @@ object NumberConverter { var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array - val temp = new Array[Byte](Math.max(n.length, 64)) + val temp = new Array[Byte](Math.max(n.length, MAX_OUTPUT_LENGTH)) var v: Long = -1 System.arraycopy(n, first, temp, temp.length - n.length + first, n.length - first) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 50ff3eeab0c16..0d947258e6555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral} -import org.apache.spark.sql.catalyst.optimizer.ConstantFolding +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION @@ -279,12 +279,15 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( statementType, colName, defaultSQL) } + // Analyze the parse result. val plan = try { val analyzer: Analyzer = DefaultColumnAnalyzer val analyzed = analyzer.execute(Project(Seq(Alias(parsed, colName)()), OneRowRelation())) analyzer.checkAnalysis(analyzed) - ConstantFolding(analyzed) + // Eagerly execute finish-analysis and constant-folding rules before checking whether the + // expression is foldable and resolved. 
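The default-column analysis path above now runs the finish-analysis and constant-folding rules eagerly and then insists on a foldable, resolved result. A hedged sketch of the user-visible effect (the local SparkSession, parquet tables, and object name are assumptions):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object ColumnDefaultFoldingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("column-defaults").getOrCreate()

    // `41 + 1` folds to the constant 42 during the eager analysis/constant-folding step.
    spark.sql("CREATE TABLE t_ok (id INT, c INT DEFAULT 41 + 1) USING parquet")

    // A non-foldable default is expected to be rejected (the NOT_CONSTANT error in this patch).
    try {
      spark.sql("CREATE TABLE t_bad (id INT, c DOUBLE DEFAULT rand()) USING parquet")
    } catch {
      case e: Exception => println(s"Rejected as expected: ${e.getMessage.take(120)}")
    }

    spark.stop()
  }
}
```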
+ ConstantFolding(DefaultColumnOptimizer.FinishAnalysis(analyzed)) } catch { case ex: AnalysisException => throw QueryCompilationErrors.defaultValuesUnresolvedExprError( @@ -293,6 +296,21 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU val analyzed: Expression = plan.collectFirst { case Project(Seq(a: Alias), OneRowRelation()) => a.child }.get + + if (!analyzed.foldable) { + throw QueryCompilationErrors.defaultValueNotConstantError(statementType, colName, defaultSQL) + } + + // Another extra check, expressions should already be resolved if AnalysisException is not + // thrown in the code block above + if (!analyzed.resolved) { + throw QueryCompilationErrors.defaultValuesUnresolvedExprError( + statementType, + colName, + defaultSQL, + cause = null) + } + // Perform implicit coercion from the provided expression type to the required column type. if (dataType == analyzed.dataType) { analyzed @@ -436,6 +454,11 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU new CatalogManager(BuiltInFunctionCatalog, BuiltInFunctionCatalog.v1Catalog)) { } + /** + * This is an Optimizer for convert default column expressions to foldable literals. + */ + object DefaultColumnOptimizer extends Optimizer(DefaultColumnAnalyzer.catalogManager) + /** * This is a FunctionCatalog for performing analysis using built-in functions only. It is a helper * for the DefaultColumnAnalyzer above. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index cf9dd7fdf4767..8080af2fb6b51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -103,19 +104,31 @@ class CatalogManager( } } - def setCurrentNamespace(namespace: Array[String]): Unit = synchronized { + private def assertNamespaceExist(namespace: Array[String]): Unit = { currentCatalog match { - case _ if isSessionCatalog(currentCatalog) && namespace.length == 1 => - v1SessionCatalog.setCurrentDatabase(namespace.head) - case _ if isSessionCatalog(currentCatalog) => - throw QueryCompilationErrors.noSuchNamespaceError(namespace) case catalog: SupportsNamespaces if !catalog.namespaceExists(namespace) => - throw QueryCompilationErrors.noSuchNamespaceError(namespace) + throw QueryCompilationErrors.noSuchNamespaceError(catalog.name() +: namespace) case _ => - _currentNamespace = Some(namespace) } } + def setCurrentNamespace(namespace: Array[String]): Unit = synchronized { + if (isSessionCatalog(currentCatalog) && namespace.length == 1) { + v1SessionCatalog.setCurrentDatabaseWithNameCheck( + namespace.head, + name => { + currentCatalog match { + case catalog: SupportsNamespaces if !catalog.namespaceExists(namespace) => + throw new NoSuchDatabaseException(name) + case _ => + } + }) + } else { + assertNamespaceExist(namespace) + } + _currentNamespace = Some(namespace) + } + private var _currentCatalogName: Option[String] = None def currentCatalog: CatalogPlugin = 
synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index be569b1de9dbc..f8f682e76cfc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -331,9 +331,10 @@ private[sql] object CatalogV2Util { def loadTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] = + timeTravelSpec: Option[TimeTravelSpec] = None, + writePrivilegesString: Option[String] = None): Option[Table] = try { - Option(getTable(catalog, ident, timeTravelSpec)) + Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString)) } catch { case _: NoSuchTableException => None case _: NoSuchDatabaseException => None @@ -343,8 +344,10 @@ private[sql] object CatalogV2Util { def getTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Table = { + timeTravelSpec: Option[TimeTravelSpec] = None, + writePrivilegesString: Option[String] = None): Table = { if (timeTravelSpec.nonEmpty) { + assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel") timeTravelSpec.get match { case v: AsOfVersion => catalog.asTableCatalog.loadTable(ident, v.version) @@ -352,7 +355,13 @@ private[sql] object CatalogV2Util { catalog.asTableCatalog.loadTable(ident, ts.timestamp) } } else { - catalog.asTableCatalog.loadTable(ident) + if (writePrivilegesString.isDefined) { + val writePrivileges = writePrivilegesString.get.split(",").map(_.trim) + .map(TableWritePrivilege.valueOf).toSet.asJava + catalog.asTableCatalog.loadTable(ident, writePrivileges) + } else { + catalog.asTableCatalog.loadTable(ident) + } } } @@ -512,10 +521,15 @@ private[sql] object CatalogV2Util { } if (isDefaultColumn) { - val e = analyze(f, EXISTS_DEFAULT_COLUMN_METADATA_KEY) + val e = analyze( + f, + statementType = "Column analysis", + metadataKey = EXISTS_DEFAULT_COLUMN_METADATA_KEY) + assert(e.resolved && e.foldable, "The existence default value must be a simple SQL string that is resolved and foldable, " + "but got: " + f.getExistenceDefaultValue().get) + val defaultValue = new ColumnDefaultValue( f.getCurrentDefaultValue().get, LiteralValue(e.eval(), f.dataType)) val cleanedMetadata = metadataWithKeysRemoved( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index da201e816497c..8928ba57f06c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -22,7 +22,7 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.connector.catalog.V1Table.addV2TableProperties import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} @@ -38,7 +38,7 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { lazy val options: Map[String, String] = { 
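The `getTable` change above forwards an optional comma-separated privilege string to the catalog. The standalone sketch below (using a Scala `Enumeration` as a stand-in for the connector's `TableWritePrivilege` enum) shows the parsing idea only; it is not the catalyst code itself.

```scala
// Standalone illustration; the enum below is a stand-in, not the connector API.
object WritePrivilegeStringParsingSketch {
  object TableWritePrivilege extends Enumeration {
    val INSERT, UPDATE, DELETE = Value
  }

  def parse(writePrivilegesString: Option[String]): Set[TableWritePrivilege.Value] =
    writePrivilegesString
      .map(_.split(",").map(_.trim).map(TableWritePrivilege.withName).toSet)
      .getOrElse(Set.empty)

  def main(args: Array[String]): Unit = {
    println(parse(Some("INSERT, DELETE"))) // contains INSERT and DELETE
    println(parse(None))                   // empty set: no write privileges requested
  }
}
```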
v1Table.storage.locationUri match { case Some(uri) => - v1Table.storage.properties + ("path" -> uri.toString) + v1Table.storage.properties + ("path" -> CatalogUtils.URIToString(uri)) case _ => v1Table.storage.properties } @@ -81,7 +81,9 @@ private[sql] object V1Table { TableCatalog.OPTION_PREFIX + key -> value } ++ v1Table.provider.map(TableCatalog.PROP_PROVIDER -> _) ++ v1Table.comment.map(TableCatalog.PROP_COMMENT -> _) ++ - v1Table.storage.locationUri.map(TableCatalog.PROP_LOCATION -> _.toString) ++ + v1Table.storage.locationUri.map { loc => + TableCatalog.PROP_LOCATION -> CatalogUtils.URIToString(loc) + } ++ (if (managed) Some(TableCatalog.PROP_IS_MANAGED_LOCATION -> "true") else None) ++ (if (external) Some(TableCatalog.PROP_EXTERNAL -> "true") else None) ++ Some(TableCatalog.PROP_OWNER -> v1Table.owner) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index fbd2520e2a774..7f536bdb712a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions +import org.apache.commons.lang3.StringUtils + import org.apache.spark.SparkException import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -350,7 +352,7 @@ private[sql] object HoursTransform { private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { override def toString: String = { if (dataType.isInstanceOf[StringType]) { - s"'$value'" + s"'${StringUtils.replace(s"$value", "'", "''")}'" } else { s"$value" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 14882a7006173..ec58298babdb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -248,11 +248,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator))) } - def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = { new AnalysisException( errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", messageParameters = Map( - "clause" -> clause, "num" -> generators.size.toString, "generators" -> generators.map(toSQLExpr).mkString(", "))) } @@ -3407,6 +3406,19 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "defaultValue" -> defaultValue)) } + def defaultValueNotConstantError( + statement: String, + colName: String, + defaultValue: String): Throwable = { + new AnalysisException( + errorClass = "INVALID_DEFAULT_VALUE.NOT_CONSTANT", + messageParameters = Map( + "statement" -> toSQLStmt(statement), + "colName" -> toSQLId(colName), + "defaultValue" -> defaultValue + )) + } + def nullableColumnOrFieldError(name: Seq[String]): Throwable = { new AnalysisException( errorClass = "NULLABLE_COLUMN_OR_FIELD", @@ -3672,6 +3684,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def avroIncompatibleReadError( + avroPath: 
String, + sqlPath: String, + avroType: String, + sqlType: String): Throwable = { + new AnalysisException( + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + messageParameters = Map( + "avroPath" -> avroPath, + "sqlPath" -> sqlPath, + "avroType" -> avroType, + "sqlType" -> toSQLType(sqlType) + ) + ) + } + def optionMustBeLiteralString(key: String): Throwable = { new AnalysisException( errorClass = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index f4968cd005708..798146839464c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -372,10 +372,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def illegalUrlError(url: UTF8String): Throwable = { + def illegalUrlError(url: UTF8String, e: IllegalArgumentException): Throwable = { new SparkIllegalArgumentException( errorClass = "CANNOT_DECODE_URL", - messageParameters = Map("url" -> url.toString) + messageParameters = Map("url" -> url.toString), + cause = e ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e0ce7c4dea9d..6f2f0088fccd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -181,7 +181,7 @@ object SQLConf { // Make sure SqlApiConf is always in sync with SQLConf. SqlApiConf will always try to // load SqlConf to make sure both classes are in sync from the get go. - SqlApiConf.setConfGetter(() => SQLConf.get) + SqlApiConfHelper.setConfGetter(() => SQLConf.get) /** * Returns the active config object within the current scope. If there is an active SparkSession, @@ -657,6 +657,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS = + buildConf("spark.sql.adaptive.applyFinalStageShuffleOptimizations") + .internal() + .doc("Configures whether adaptive query execution (if enabled) should apply shuffle " + + "coalescing and local shuffle read optimization for the final query stage.") + .version("3.4.2") + .booleanConf + .createWithDefault(true) + val ADAPTIVE_EXECUTION_LOG_LEVEL = buildConf("spark.sql.adaptive.logLevel") .internal() .doc("Configures the log level for adaptive execution logging of plan changes. The value " + @@ -885,7 +894,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val CASE_SENSITIVE = buildConf(SqlApiConf.CASE_SENSITIVE_KEY) + val CASE_SENSITIVE = buildConf(SqlApiConfHelper.CASE_SENSITIVE_KEY) .internal() .doc("Whether the query analyzer should be case sensitive or not. " + "Default to case insensitive. It is highly discouraged to turn on case sensitive mode.") @@ -995,12 +1004,22 @@ object SQLConf { "`parquet.compression` is specified in the table-specific options/properties, the " + "precedence would be `compression`, `parquet.compression`, " + "`spark.sql.parquet.compression.codec`. 
Acceptable values include: none, uncompressed, " + - "snappy, gzip, lzo, brotli, lz4, lz4raw, zstd.") + "snappy, gzip, lzo, brotli, lz4, lz4raw, lz4_raw, zstd.") .version("1.1.1") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues( - Set("none", "uncompressed", "snappy", "gzip", "lzo", "brotli", "lz4", "lz4raw", "zstd")) + Set( + "none", + "uncompressed", + "snappy", + "gzip", + "lzo", + "brotli", + "lz4", + "lz4raw", + "lz4_raw", + "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -1519,7 +1538,7 @@ object SQLConf { .doc("Whether to forcibly enable some optimization rules that can change the output " + "partitioning of a cached query when executing it for caching. If it is set to true, " + "queries may need an extra shuffle to read the cached data. This configuration is " + - "disabled by default. Currently, the optimization rules enabled by this configuration " + + "enabled by default. The optimization rules enabled by this configuration " + s"are ${ADAPTIVE_EXECUTION_ENABLED.key} and ${AUTO_BUCKETED_SCAN_ENABLED.key}.") .version("3.2.0") .booleanConf @@ -2101,7 +2120,9 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled") .internal() .doc("When true, this config will skip null values in hash based stream-stream joins. " + - "The number of skipped null values will be shown as custom metric of stream join operator.") + "The number of skipped null values will be shown as custom metric of stream join operator. " + + "If the streaming query was started with Spark 3.5 or above, please exercise caution " + + "before enabling this config since it may hide potential data loss/corruption issues.") .version("3.3.0") .booleanConf .createWithDefault(false) @@ -2123,6 +2144,17 @@ object SQLConf { .createWithDefault(true) + val STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED = + buildConf("spark.sql.streaming.optimizeOneRowPlan.enabled") + .internal() + .doc("When true, enable OptimizeOneRowPlan rule for the case where the child is a " + + "streaming Dataset. This is a fallback flag to revert the 'incorrect' behavior, hence " + + "this configuration must not be used without understanding in depth. Use this only to " + + "quickly recover failure in existing query!") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val VARIABLE_SUBSTITUTE_ENABLED = buildConf("spark.sql.variable.substitute") .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " + @@ -2657,7 +2689,7 @@ object SQLConf { Try { DateTimeUtils.getZoneId(zone) }.isSuccess } - val SESSION_LOCAL_TIMEZONE = buildConf(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY) + val SESSION_LOCAL_TIMEZONE = buildConf(SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY) .doc("The ID of session local timezone in the format of either region-based zone IDs or " + "zone offsets. Region IDs must have the form 'area/city', such as 'America/Los_Angeles'. " + "Zone offsets must be in the format '(+|-)HH', '(+|-)HH:mm' or '(+|-)HH:mm:ss', e.g '-08', " + @@ -2859,7 +2891,9 @@ object SQLConf { val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + - "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + "to a single ArrowRecordBatch in memory. 
This configuration is not effective for the " + + "grouping API such as DataFrame(.cogroup).groupby.applyInPandas because each group " + + "becomes each ArrowRecordBatch. If set to zero or negative there is no limit.") .version("2.3.0") .intConf .createWithDefault(10000) @@ -3161,7 +3195,7 @@ object SQLConf { .checkValues(StoreAssignmentPolicy.values.map(_.toString)) .createWithDefault(StoreAssignmentPolicy.ANSI.toString) - val ANSI_ENABLED = buildConf(SqlApiConf.ANSI_ENABLED_KEY) + val ANSI_ENABLED = buildConf(SqlApiConfHelper.ANSI_ENABLED_KEY) .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " + "For example, Spark will throw an exception at runtime instead of returning null results " + "when the inputs to a SQL operator/function are invalid." + @@ -3195,6 +3229,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CHUNK_BASE64_STRING_ENABLED = buildConf("spark.sql.chunkBase64String.enabled") + .doc("Whether to truncate string generated by the `Base64` function. When true, base64" + + " strings generated by the base64 function are chunked into lines of at most 76" + + " characters. When false, the base64 strings are not chunked.") + .version("3.5.2") + .booleanConf + .createWithDefault(true) + val ENABLE_DEFAULT_COLUMNS = buildConf("spark.sql.defaultColumn.enabled") .internal() @@ -3378,6 +3420,15 @@ object SQLConf { .intConf .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN = + buildConf("spark.databricks.sql.optimizer.pruneFiltersCanPruneStreamingSubplan") + .internal() + .doc("Allow PruneFilters to remove streaming subplans when we encounter a false filter. " + + "This flag is to restore prior buggy behavior for broken pipelines.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -3760,7 +3811,7 @@ object SQLConf { .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) - val LEGACY_TIME_PARSER_POLICY = buildConf(SqlApiConf.LEGACY_TIME_PARSER_POLICY_KEY) + val LEGACY_TIME_PARSER_POLICY = buildConf(SqlApiConfHelper.LEGACY_TIME_PARSER_POLICY_KEY) .internal() .doc("When LEGACY, java.text.SimpleDateFormat is used for formatting and parsing " + "dates/timestamps in a locale-sensitive manner, which is the approach before Spark 3.0. " + @@ -4141,6 +4192,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate") + .internal() + .doc("When true, Spark will not apply char type padding for CHAR type columns in string " + + s"comparison predicates, when '${READ_SIDE_CHAR_PADDING.key}' is false.") + .version("3.5.2") + .booleanConf + .createWithDefault(false) + val CLI_PRINT_HEADER = buildConf("spark.sql.cli.print.header") .doc("When set to true, spark-sql CLI prints the names of the columns in query output.") @@ -4251,6 +4310,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA = + buildConf("spark.sql.legacy.avro.allowIncompatibleSchema") + .internal() + .doc("When set to false, if types in Avro are encoded in the same format, but " + + "the type in the Avro schema explicitly says that the data types are different, " + + "reject reading the data type in the format to avoid returning incorrect results. 
" + + "When set to true, it restores the legacy behavior of allow reading the data in the" + + " format, which may return incorrect results.") + .version("3.5.1") + .booleanConf + .createWithDefault(false) + val LEGACY_NON_IDENTIFIER_OUTPUT_CATALOG_NAME = buildConf("spark.sql.legacy.v1IdentifierNoCatalog") .internal() @@ -4313,7 +4384,7 @@ object SQLConf { .createWithDefault(false) val LOCAL_RELATION_CACHE_THRESHOLD = - buildConf(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY) + buildConf(SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY) .doc("The threshold for the size in bytes of local relations to be cached at " + "the driver side after serialization.") .version("3.5.0") @@ -4321,6 +4392,15 @@ object SQLConf { .checkValue(_ >= 0, "The threshold of cached local relations must not be negative") .createWithDefault(64 * 1024 * 1024) + val EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES = + buildConf("spark.sql.optimizer.excludeSubqueryRefsFromRemoveRedundantAliases.enabled") + .internal() + .doc("When true, exclude the references from the subquery expressions (in, exists, etc.) " + + s"while removing redundant aliases.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation") .internal() .doc("If true, the old bogus percentile_disc calculation is used. The old calculation " + @@ -5048,6 +5128,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def ansiRelationPrecedence: Boolean = ansiEnabled && getConf(ANSI_RELATION_PRECEDENCE) + def chunkBase64StringEnabled: Boolean = getConf(CHUNK_BASE64_STRING_ENABLED) + def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match { case "TIMESTAMP_LTZ" => // For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaTypeInferenceBeans.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaTypeInferenceBeans.java new file mode 100644 index 0000000000000..cc3540717ee7d --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/JavaTypeInferenceBeans.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst; + +public class JavaTypeInferenceBeans { + + static class JavaBeanWithGenericsA { + public T getPropertyA() { + return null; + } + + public void setPropertyA(T a) { + + } + } + + static class JavaBeanWithGenericsAB extends JavaBeanWithGenericsA { + public T getPropertyB() { + return null; + } + + public void setPropertyB(T a) { + + } + } + + static class JavaBeanWithGenericsABC extends JavaBeanWithGenericsAB { + public T getPropertyC() { + return null; + } + + public void setPropertyC(T a) { + + } + } + + static class JavaBeanWithGenerics { + private A attribute; + + private T value; + + public A getAttribute() { + return attribute; + } + + public void setAttribute(A attribute) { + this.attribute = attribute; + } + + public T getValue() { + return value; + } + + public void setValue(T value) { + this.value = value; + } + } + + static class JavaBeanWithGenericBase extends JavaBeanWithGenerics { + + } + + static class JavaBeanWithGenericHierarchy extends JavaBeanWithGenericsABC { + + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index a924a9ed02e5d..7cb4d5f123253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Murmur3Hash, Pmod} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType @@ -146,63 +146,75 @@ class DistributionSuite extends SparkFunSuite { false) } - test("HashPartitioning is the output partitioning") { - // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of - // the required clustering expressions. - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c")), - true) - - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c")), - true) - - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"b", $"c")), - false) - - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"d", $"e")), - false) - - // When ClusteredDistribution.requireAllClusterKeys is set to true, - // HashPartitioning can only satisfy ClusteredDistribution iff its hash expressions are - // exactly same as the required clustering expressions. 
- checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - true) - - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - false) - - checkSatisfied( - HashPartitioning(Seq($"b", $"a", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), - false) - - // HashPartitioning cannot satisfy OrderedDistribution - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 10), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) + private def testHashPartitioningLike( + partitioningName: String, + create: (Seq[Expression], Int) => Partitioning): Unit = { + + test(s"$partitioningName is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c")), + true) + + checkSatisfied( + create(Seq($"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c")), + true) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"b", $"c")), + false) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"d", $"e")), + false) + + // When ClusteredDistribution.requireAllClusterKeys is set to true, + // HashPartitioning can only satisfy ClusteredDistribution iff its hash expressions are + // exactly same as the required clustering expressions. + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + true) + + checkSatisfied( + create(Seq($"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + checkSatisfied( + create(Seq($"b", $"a", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true), + false) + + // HashPartitioning cannot satisfy OrderedDistribution + checkSatisfied( + create(Seq($"a", $"b", $"c"), 10), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) + + checkSatisfied( + create(Seq($"a", $"b", $"c"), 1), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) // TODO: this can be relaxed. + + checkSatisfied( + create(Seq($"b", $"c"), 10), + OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), + false) + } + } - checkSatisfied( - HashPartitioning(Seq($"a", $"b", $"c"), 1), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) // TODO: this can be relaxed. 
+ testHashPartitioningLike("HashPartitioning", + (expressions, numPartitions) => HashPartitioning(expressions, numPartitions)) - checkSatisfied( - HashPartitioning(Seq($"b", $"c"), 10), - OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)), - false) - } + testHashPartitioningLike("CoalescedHashPartitioning", (expressions, numPartitions) => + CoalescedHashPartitioning( + HashPartitioning(expressions, numPartitions), Seq(CoalescedBoundary(0, numPartitions)))) test("RangePartitioning is the output partitioning") { // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index 6439997609766..f7c1043d1cb8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -24,6 +24,7 @@ import scala.beans.{BeanProperty, BooleanBeanProperty} import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.JavaTypeInferenceBeans.{JavaBeanWithGenericBase, JavaBeanWithGenericHierarchy, JavaBeanWithGenericsABC} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType} @@ -66,7 +67,8 @@ class LeafBean { @BeanProperty var period: java.time.Period = _ @BeanProperty var enum: java.time.Month = _ @BeanProperty val readOnlyString = "read-only" - @BeanProperty var genericNestedBean: JavaBeanWithGenerics[String, String] = _ + @BeanProperty var genericNestedBean: JavaBeanWithGenericBase = _ + @BeanProperty var genericNestedBean2: JavaBeanWithGenericsABC[Integer] = _ var nonNullString: String = "value" @javax.annotation.Nonnull @@ -186,8 +188,18 @@ class JavaTypeInferenceSuite extends SparkFunSuite { encoderField("duration", DayTimeIntervalEncoder), encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])), encoderField("genericNestedBean", JavaBeanEncoder( - ClassTag(classOf[JavaBeanWithGenerics[String, String]]), - Seq(encoderField("attribute", StringEncoder), encoderField("value", StringEncoder)))), + ClassTag(classOf[JavaBeanWithGenericBase]), + Seq( + encoderField("attribute", StringEncoder), + encoderField("value", StringEncoder) + ))), + encoderField("genericNestedBean2", JavaBeanEncoder( + ClassTag(classOf[JavaBeanWithGenericsABC[Integer]]), + Seq( + encoderField("propertyA", StringEncoder), + encoderField("propertyB", BoxedLongEncoder), + encoderField("propertyC", BoxedIntEncoder) + ))), encoderField("instant", STRICT_INSTANT_ENCODER), encoderField("localDate", STRICT_LOCAL_DATE_ENCODER), encoderField("localDateTime", LocalDateTimeEncoder), @@ -224,4 +236,27 @@ class JavaTypeInferenceSuite extends SparkFunSuite { )) assert(encoder === expected) } + + test("SPARK-44910: resolve bean with generic base class") { + val encoder = + JavaTypeInference.encoderFor(classOf[JavaBeanWithGenericBase]) + val expected = + JavaBeanEncoder(ClassTag(classOf[JavaBeanWithGenericBase]), Seq( + encoderField("attribute", StringEncoder), + encoderField("value", StringEncoder) + )) + assert(encoder === expected) + } + + test("SPARK-44910: resolve bean with hierarchy of generic classes") { + val encoder = + 
JavaTypeInference.encoderFor(classOf[JavaBeanWithGenericHierarchy]) + val expected = + JavaBeanEncoder(ClassTag(classOf[JavaBeanWithGenericHierarchy]), Seq( + encoderField("propertyA", StringEncoder), + encoderField("propertyB", BoxedLongEncoder), + encoderField("propertyC", BoxedIntEncoder) + )) + assert(encoder === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 51e7688732265..6b069d1c97363 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -62,211 +62,254 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { } } - test("compatibility: HashShuffleSpec on both sides") { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) - - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) + private def testHashShuffleSpecLike( + shuffleSpecName: String, + create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit = { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), ClusteredDistribution(Seq($"c", $"d"))), - expected = true - ) + test(s"compatibility: $shuffleSpecName on both sides") { + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"c", $"c", $"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a"), 10), ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"d"), 10), - ClusteredDistribution(Seq($"a", $"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"d"), 10), ClusteredDistribution(Seq($"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10), - ClusteredDistribution(Seq($"a", $"c", $"c"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"c", $"c", $"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - 
HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10), - ClusteredDistribution(Seq($"a", $"c", $"d"))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"d"), 10), + ClusteredDistribution(Seq($"a", $"c", $"d"))), + expected = true + ) - // negative cases - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"c"), 5), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"c", $"a"), 10), + ClusteredDistribution(Seq($"a", $"c", $"c"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"c", $"a"), 10), + ClusteredDistribution(Seq($"a", $"c", $"d"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + // negative cases + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"c"), 5), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10), - ClusteredDistribution(Seq($"a", $"b", $"b"))), - expected = false - ) - } + checkCompatible( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + 
create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) - test("compatibility: Only one side is HashShuffleSpec") { - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - SinglePartitionShuffleSpec, - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1), - ClusteredDistribution(Seq($"a", $"b"))), - SinglePartitionShuffleSpec, - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"a"), 10), + ClusteredDistribution(Seq($"a", $"b", $"b"))), + expected = false + ) + } - checkCompatible( - SinglePartitionShuffleSpec, - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1), - ClusteredDistribution(Seq($"a", $"b"))), - expected = true - ) + test(s"compatibility: Only one side is $shuffleSpecName") { + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + SinglePartitionShuffleSpec, + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 1), + ClusteredDistribution(Seq($"a", $"b"))), + SinglePartitionShuffleSpec, + expected = true + ) - checkCompatible( - RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - expected = false - ) + checkCompatible( + SinglePartitionShuffleSpec, + create(HashPartitioning(Seq($"a", $"b"), 1), + ClusteredDistribution(Seq($"a", $"b"))), + expected = true + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - expected = true - ) + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), + checkCompatible( + RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - expected = true - ) + expected = false + ) - checkCompatible( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a"), 10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))))), - expected = false - 
) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + expected = true + ) - checkCompatible( - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))), - HashShuffleSpec(HashPartitioning(Seq($"d"), 10), - ClusteredDistribution(Seq($"c", $"d"))))), - expected = true - ) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + expected = true + ) - checkCompatible( - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"b"), 10), + checkCompatible( + create(HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))), - HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), - ClusteredDistribution(Seq($"a", $"b"))))), - ShuffleSpecCollection(Seq( - HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10), - ClusteredDistribution(Seq($"a", $"b", $"c"))), - HashShuffleSpec(HashPartitioning(Seq($"c"), 10), - ClusteredDistribution(Seq($"c", $"d"))))), - expected = false - ) + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))))), + expected = false + ) + + checkCompatible( + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))), + create(HashPartitioning(Seq($"d"), 10), + ClusteredDistribution(Seq($"c", $"d"))))), + expected = true + ) + + checkCompatible( + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))), + create(HashPartitioning(Seq($"a", $"b"), 10), + ClusteredDistribution(Seq($"a", $"b"))))), + ShuffleSpecCollection(Seq( + create(HashPartitioning(Seq($"a", $"b", $"c"), 10), + ClusteredDistribution(Seq($"a", $"b", $"c"))), + create(HashPartitioning(Seq($"c"), 10), + ClusteredDistribution(Seq($"c", $"d"))))), + expected = false + ) + } + } + + testHashShuffleSpecLike("HashShuffleSpec", + (partitioning, distribution) => HashShuffleSpec(partitioning, distribution)) + testHashShuffleSpecLike("CoalescedHashShuffleSpec", + (partitioning, distribution) => { + val partitions = if (partitioning.numPartitions == 1) { + Seq(CoalescedBoundary(0, 1)) + } else { + Seq(CoalescedBoundary(0, 1), CoalescedBoundary(0, partitioning.numPartitions)) + } + CoalescedHashShuffleSpec(HashShuffleSpec(partitioning, distribution), partitions) + }) + + test("compatibility: CoalescedHashShuffleSpec other specs") { + val hashShuffleSpec = HashShuffleSpec( + HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b"))) + checkCompatible( + hashShuffleSpec, + CoalescedHashShuffleSpec(hashShuffleSpec, Seq(CoalescedBoundary(0, 10))), + expected = false + ) + + checkCompatible( + 
CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + expected = true + ) + + checkCompatible( + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 4), CoalescedBoundary(4, 10))), + CoalescedHashShuffleSpec(hashShuffleSpec, + Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))), + expected = false + ) } test("compatibility: other specs") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e2e980073307d..6b5f0fe3876da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -118,6 +118,13 @@ case class TestFunctionWithTypeCheckFailure( case class UnresolvedTestPlan() extends UnresolvedLeafNode +case class SupportsNonDeterministicExpressionTestOperator( + actions: Seq[Expression], + allowNonDeterministicExpression: Boolean) + extends LeafNode with SupportsNonDeterministicExpression { + override def output: Seq[Attribute] = Seq() +} + class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ @@ -344,10 +351,39 @@ class AnalysisErrorSuite extends AnalysisTest { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"INT\"")) - errorTest( - "too many generators", - listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")), - "only one generator" :: "explode" :: Nil) + errorClassTest( + "the buckets of ntile window function is not foldable", + testRelation2.select( + WindowExpression( + NTile(Literal(99.9f)), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as("window")), + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "sqlExpr" -> "\"ntile(99.9)\"", + "paramIndex" -> "1", + "inputSql" -> "\"99.9\"", + "inputType" -> "\"FLOAT\"", + "requiredType" -> "\"INT\"")) + + + errorClassTest( + "the buckets of ntile window function is not int literal", + testRelation2.select( + WindowExpression( + NTile(AttributeReference("b", IntegerType)()), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as("window")), + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + messageParameters = Map( + "sqlExpr" -> "\"ntile(b)\"", + "inputName" -> "buckets", + "inputExpr" -> "\"b\"", + "inputType" -> "\"INT\"")) errorClassTest( "unresolved attributes", @@ -754,18 +790,11 @@ class AnalysisErrorSuite extends AnalysisTest { "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT", Map("limit" -> "1000000000", "offset" -> "2000000000")) - errorTest( - "more than one generators in SELECT", - listRelation.select(Explode($"list"), Explode($"list")), - "The generator is not supported: only one generator allowed per select clause but found 2: " + - """"explode(list)", "explode(list)"""" :: Nil - ) - errorTest( "more than one generators for aggregates in SELECT", testRelation.select(Explode(CreateArray(min($"a") :: Nil)), Explode(CreateArray(max($"a") :: Nil))), - "The generator is not supported: only one generator allowed per select clause but found 2: " + + "The generator is not supported: only one generator allowed per SELECT clause but 
found 2: " + """"explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil ) @@ -1305,4 +1334,20 @@ class AnalysisErrorSuite extends AnalysisTest { ) } } + + test("SPARK-48871: SupportsNonDeterministicExpression allows non-deterministic expressions") { + val nonDeterministicExpressions = Seq(new Rand()) + val tolerantPlan = + SupportsNonDeterministicExpressionTestOperator( + nonDeterministicExpressions, allowNonDeterministicExpression = true) + assertAnalysisSuccess(tolerantPlan) + + val intolerantPlan = + SupportsNonDeterministicExpressionTestOperator( + nonDeterministicExpressions, allowNonDeterministicExpression = false) + assertAnalysisError( + intolerantPlan, + "INVALID_NON_DETERMINISTIC_EXPRESSIONS" :: Nil + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 06c3e3eb0405a..70a7fdc81e7e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -771,34 +771,35 @@ class AnalysisSuite extends AnalysisTest with Matchers { val literal = Literal(1).as("lit") // Ok - assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved) + assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation, 0).resolved) // Bad name - assert(!CollectMetrics("", sum :: Nil, testRelation).resolved) + assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved) assertAnalysisErrorClass( - CollectMetrics("", sum :: Nil, testRelation), + CollectMetrics("", sum :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME", expectedMessageParameters = Map( - "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+- LocalRelation , [a#x]\n") + "operator" -> + "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation , [a#x]\n") ) // No columns - assert(!CollectMetrics("evt", Nil, testRelation).resolved) + assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved) def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = { - assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors) + assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors) } // Unwrapped attribute assertAnalysisErrorClass( - CollectMetrics("event", a :: Nil, testRelation), + CollectMetrics("event", a :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE", expectedMessageParameters = Map("expr" -> "\"a\"") ) // Unwrapped non-deterministic expression assertAnalysisErrorClass( - CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation), + CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"") ) @@ -808,7 +809,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics( "event", Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil, - testRelation), + testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"") @@ -819,7 +820,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics( "event", 
Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil, - testRelation), + testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"") ) @@ -830,7 +831,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { WindowSpecDefinition(Nil, a.asc :: Nil, SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) assertAnalysisErrorClass( - CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation), + CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED", expectedMessageParameters = Map( "expr" -> @@ -848,14 +849,14 @@ class AnalysisSuite extends AnalysisTest with Matchers { // Same result - duplicate names are allowed assertAnalysisSuccess(Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", count :: Nil, testRelation) :: Nil)) + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil)) // Same children, structurally different metrics - fail assertAnalysisErrorClass( Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil), + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil), expectedErrorClass = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) @@ -865,17 +866,17 @@ class AnalysisSuite extends AnalysisTest with Matchers { val tblB = LocalRelation(b) assertAnalysisErrorClass( Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", count :: Nil, tblB) :: Nil), + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil), expectedErrorClass = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) // Subquery different tree - fail - val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation)) + val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation, 0)) val query = Project( b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil, - CollectMetrics("evt1", count :: Nil, tblB)) + CollectMetrics("evt1", count :: Nil, tblB, 1)) assertAnalysisErrorClass( query, expectedErrorClass = "DUPLICATED_METRICS_NAME", @@ -887,7 +888,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { case a: AggregateExpression => a.copy(filter = Some(true)) }.asInstanceOf[NamedExpression] assertAnalysisErrorClass( - CollectMetrics("evt1", sumWithFilter :: Nil, testRelation), + CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"") @@ -1516,7 +1517,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-43030: deduplicate relations in CTE relation definitions") { val join = testRelation.as("left").join(testRelation.as("right")) val cteDef = CTERelationDef(join) - val cteRef = CTERelationRef(cteDef.id, false, Nil) + val cteRef = CTERelationRef(cteDef.id, false, Nil, false) withClue("flat CTE") { val plan = WithCTE(cteRef.select($"left.a"), Seq(cteDef)).analyze @@ -1529,7 +1530,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { 
withClue("nested CTE") { val cteDef2 = CTERelationDef(WithCTE(cteRef.join(testRelation), Seq(cteDef))) - val cteRef2 = CTERelationRef(cteDef2.id, false, Nil) + val cteRef2 = CTERelationRef(cteDef2.id, false, Nil, false) val plan = WithCTE(cteRef2, Seq(cteDef2)).analyze val relations = plan.collect { case r: LocalRelation => r @@ -1541,7 +1542,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-43030: deduplicate CTE relation references") { val cteDef = CTERelationDef(testRelation.select($"a")) - val cteRef = CTERelationRef(cteDef.id, false, Nil) + val cteRef = CTERelationRef(cteDef.id, false, Nil, false) withClue("single reference") { val plan = WithCTE(cteRef.where($"a" > 1), Seq(cteDef)).analyze @@ -1564,7 +1565,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { withClue("CTE relation has duplicated attributes") { val cteDef = CTERelationDef(testRelation.select($"a", $"a")) - val cteRef = CTERelationRef(cteDef.id, false, Nil) + val cteRef = CTERelationRef(cteDef.id, false, Nil, false) val plan = WithCTE(cteRef.join(cteRef.select($"a")), Seq(cteDef)).analyze val refs = plan.collect { case r: CTERelationRef => r @@ -1576,14 +1577,14 @@ class AnalysisSuite extends AnalysisTest with Matchers { withClue("CTE relation has duplicate aliases") { val alias = Alias($"a", "x")() val cteDef = CTERelationDef(testRelation.select(alias, alias).where($"x" === 1)) - val cteRef = CTERelationRef(cteDef.id, false, Nil) + val cteRef = CTERelationRef(cteDef.id, false, Nil, false) // Should not fail with the assertion failure: Found duplicate rewrite attributes. WithCTE(cteRef.join(cteRef), Seq(cteDef)).analyze } withClue("references in both CTE relation definition and main query") { val cteDef2 = CTERelationDef(cteRef.where($"a" > 2)) - val cteRef2 = CTERelationRef(cteDef2.id, false, Nil) + val cteRef2 = CTERelationRef(cteDef2.id, false, Nil, false) val plan = WithCTE(cteRef.union(cteRef2), Seq(cteDef, cteDef2)).analyze val refs = plan.collect { case r: CTERelationRef => r @@ -1661,10 +1662,69 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(testRelation.select(ident2), testRelation.select($"a").analyze) } withClue("IDENTIFIER as table") { - val ident = PlanWithUnresolvedIdentifier(name, _ => testRelation) + val ident = new PlanWithUnresolvedIdentifier(name, _ => testRelation) checkAnalysis(ident.select($"a"), testRelation.select($"a").analyze) - val ident2 = PlanWithUnresolvedIdentifier(replaceable, _ => testRelation) + val ident2 = new PlanWithUnresolvedIdentifier(replaceable, _ => testRelation) checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze) } } + + test("SPARK-46064 Basic functionality of elimination for watermark node in batch query") { + val dfWithEventTimeWatermark = EventTimeWatermark($"ts", + IntervalUtils.fromIntervalString("10 seconds"), batchRelationWithTs) + + val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker) + + // EventTimeWatermark node is eliminated via EliminateEventTimeWatermark. 
+ assert(!analyzed.exists(_.isInstanceOf[EventTimeWatermark])) + } + + test("SPARK-46064 EliminateEventTimeWatermark properly handles the case where the child of " + + "EventTimeWatermark changes the isStreaming flag during resolution") { + // UnresolvedRelation which is batch initially and will be resolved as streaming + val dfWithTempView = UnresolvedRelation(TableIdentifier("streamingTable")) + val dfWithEventTimeWatermark = EventTimeWatermark($"ts", + IntervalUtils.fromIntervalString("10 seconds"), dfWithTempView) + + val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker) + + // EventTimeWatermark node is NOT eliminated. + assert(analyzed.exists(_.isInstanceOf[EventTimeWatermark])) + } + + test("SPARK-46062: isStreaming flag is synced from CTE definition to CTE reference") { + val cteDef = CTERelationDef(streamingRelation.select($"a", $"ts")) + // Intentionally marking the flag _resolved to false, so that analyzer has a chance to sync + // the flag isStreaming on syncing the flag _resolved. + val cteRef = CTERelationRef(cteDef.id, _resolved = false, Nil, isStreaming = false) + val plan = WithCTE(cteRef, Seq(cteDef)).analyze + + val refs = plan.collect { + case r: CTERelationRef => r + } + assert(refs.length == 1) + assert(refs.head.resolved) + assert(refs.head.isStreaming) + } + + test("SPARK-47927: ScalaUDF output nullability") { + val udf = ScalaUDF( + function = (i: Int) => i + 1, + dataType = IntegerType, + children = $"a" :: Nil, + nullable = false, + inputEncoders = Seq(Some(ExpressionEncoder[Int]().resolveAndBind()))) + val plan = testRelation.select(udf.as("u")).select($"u").analyze + assert(plan.output.head.nullable) + } + + test("SPARK-49782: ResolveDataFrameDropColumns rule resolves complex UnresolvedAttribute") { + val function = UnresolvedFunction("trim", Seq(UnresolvedAttribute("i")), isDistinct = false) + val addColumnF = Project(Seq(UnresolvedAttribute("i"), Alias(function, "f")()), testRelation5) + // Drop column "f" via ResolveDataFrameDropColumns rule. + val inputPlan = DataFrameDropColumns(Seq(UnresolvedAttribute("f")), addColumnF) + // The expected Project (root node) should only have column "i". 
+ val expectedPlan = Project(Seq(UnresolvedAttribute("i")), addColumnF).analyze + checkAnalysis(inputPlan, expectedPlan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44f..5152666473286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -84,6 +84,8 @@ trait AnalysisTest extends PlanTest { createTempView(catalog, "TaBlE3", TestRelations.testRelation3, overrideIfExists = true) createGlobalTempView(catalog, "TaBlE4", TestRelations.testRelation4, overrideIfExists = true) createGlobalTempView(catalog, "TaBlE5", TestRelations.testRelation5, overrideIfExists = true) + createTempView(catalog, "streamingTable", TestRelations.streamingRelation, + overrideIfExists = true) new Analyzer(catalog) { override val extendedResolutionRules = extendedAnalysisRules } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 665204cd0c58e..08be4c8acc4b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -602,7 +602,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"round(intField, intField)\"", - "inputName" -> "scala", + "inputName" -> "scale", "inputType" -> "\"INT\"", "inputExpr" -> "\"intField\"")) @@ -649,7 +649,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"bround(intField, intField)\"", - "inputName" -> "scala", + "inputName" -> "scale", "inputType" -> "\"INT\"", "inputExpr" -> "\"intField\"")) checkError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala index ae32365e69bbc..1fd81349ac720 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -47,7 +47,7 @@ class LookupFunctionsSuite extends PlanTest { ignoreIfExists = false) val catalog = new SessionCatalog(externalCatalog, new SimpleFunctionRegistry) val catalogManager = new CatalogManager(new CustomV2SessionCatalog(catalog), catalog) - catalog.setCurrentDatabase("db1") + catalogManager.setCurrentNamespace(Array("db1")) try { val analyzer = new Analyzer(catalogManager) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveEncodersInUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveEncodersInUDFSuite.scala new file mode 100644 index 0000000000000..c7391ad9a4305 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveEncodersInUDFSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Exists, ScalaUDF} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Filter, MergeIntoTable, ReplaceData, UpdateAction} +import org.apache.spark.sql.catalyst.trees.TreePattern +import org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class ResolveEncodersInUDFSuite extends AnalysisTest { + test("SPARK-48921: ScalaUDF encoders in subquery should be resolved for MergeInto") { + val table = new InMemoryRowLevelOperationTable("table", + StructType(StructField("a", IntegerType) :: + StructField("b", DoubleType) :: + StructField("c", StringType) :: Nil), + Array.empty, + new java.util.HashMap[String, String]() + ) + val relation = DataSourceV2Relation(table, + Seq(AttributeReference("a", IntegerType)(), + AttributeReference("b", DoubleType)(), + AttributeReference("c", StringType)()), + None, + None, + CaseInsensitiveStringMap.empty() + ) + + + val string = relation.output(2) + val udf = ScalaUDF((_: String) => "x", StringType, string :: Nil, + Option(ExpressionEncoder[String]()) :: Nil) + + val mergeIntoSource = + relation + .where($"c" === udf) + .select($"a", $"b") + .limit(1) + val cond = mergeIntoSource.output(0) == relation.output(0) && + mergeIntoSource.output(1) == relation.output(1) + + val mergePlan = MergeIntoTable( + relation, + mergeIntoSource, + cond, + Seq(UpdateAction(None, + Seq(Assignment(relation.output(0), relation.output(0)), + Assignment(relation.output(1), relation.output(1)), + Assignment(relation.output(2), relation.output(2))))), + Seq.empty, + Seq.empty) + + val replaceData = mergePlan.analyze.asInstanceOf[ReplaceData] + + val existsPlans = replaceData.groupFilterCondition.map(_.collect { + case e: Exists => + e.plan.collect { + case f: Filter if f.containsPattern(TreePattern.SCALA_UDF) => f + } + }.flatten) + + assert(existsPlans.isDefined) + + val udfs = existsPlans.get.map(_.expressions.flatMap(e => e.collect { + case s: ScalaUDF => + assert(s.inputEncoders.nonEmpty) + val encoder = s.inputEncoders.head + assert(encoder.isDefined) + assert(encoder.get.objDeserializer.resolved) + + s + })).flatten + assert(udfs.size == 1) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index 2e6c6e4eaf4c3..758b6b73e4eb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CurrentTimestamp, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.optimizer.{ComputeCurrentTime, EvalInlineTables} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} @@ -83,9 +84,10 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { assert(ResolveInlineTables(table) == table) } - test("convert") { + test("cast and execute") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted = ResolveInlineTables.convert(table) + val resolved = ResolveInlineTables.findCommonTypesAndCast(table) + val converted = ResolveInlineTables.earlyEvalIfPossible(resolved).asInstanceOf[LocalRelation] assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) @@ -93,11 +95,28 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { assert(converted.data(1).getLong(0) == 2L) } + test("cast and execute CURRENT_LIKE expressions") { + val table = UnresolvedInlineTable(Seq("c1"), Seq( + Seq(CurrentTimestamp()), Seq(CurrentTimestamp()))) + val casted = ResolveInlineTables.findCommonTypesAndCast(table) + val earlyEval = ResolveInlineTables.earlyEvalIfPossible(casted) + // Early eval should keep it in expression form. + assert(earlyEval.isInstanceOf[ResolvedInlineTable]) + + EvalInlineTables(ComputeCurrentTime(earlyEval)) match { + case LocalRelation(output, data, _) => + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 2) + // Make sure that both CURRENT_TIMESTAMP expressions are evaluated to the same value. 
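For context on the assertion that follows: ComputeCurrentTime replaces every CURRENT_TIMESTAMP (and the other CURRENT_* functions) within a single query with one literal captured when the rule runs, so both rows of the inline table necessarily carry the same instant. Conceptually (a sketch in comment form only, not the rule's actual implementation):

    // before: ResolvedInlineTable(Seq(Seq(CurrentTimestamp()), Seq(CurrentTimestamp())), output)
    // after:  ResolvedInlineTable(Seq(Seq(ts), Seq(ts)), output)   where ts is a single pinned Literal
    // EvalInlineTables then materializes those rows into the LocalRelation matched above.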
+ assert(data(0).getLong(0) == data(1).getLong(0)) + } + } + test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone.apply(table) - val LocalRelation(output, data, _) = ResolveInlineTables.apply(withTimeZone) + val LocalRelation(output, data, _) = EvalInlineTables(ResolveInlineTables.apply(withTimeZone)) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) @@ -107,11 +126,11 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted1 = ResolveInlineTables.convert(table1) + val converted1 = ResolveInlineTables.findCommonTypesAndCast(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) - val converted2 = ResolveInlineTables.convert(table2) + val converted2 = ResolveInlineTables.findCommonTypesAndCast(table2) assert(converted2.schema.fields(0).nullable) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index b0d7ace646e2e..39cf298aec434 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 +import org.apache.spark.sql.catalyst.analysis.TestRelations.{testRelation, testRelation2} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { @@ -67,4 +68,40 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { testRelation2.groupBy(Literal(1), Literal(2))($"a", $"b")) } } + + test("SPARK-45920: group by ordinal repeated analysis") { + val plan = testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze + comparePlans( + plan, + testRelation.groupBy(Literal(1))(Literal(100).as("a")) + ) + + val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any)))) + // Copy the plan to reset its `analyzed` flag, so that analyzer rules will re-apply. 
+ val copiedPlan = plan.transform { + case _: LocalRelation => testRelationWithData + } + comparePlans( + copiedPlan.analyze, // repeated analysis + testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")) + ) + } + + test("SPARK-47895: group by all repeated analysis") { + val plan = testRelation.groupBy($"all")(Literal(100).as("a")).analyze + comparePlans( + plan, + testRelation.groupBy(Literal(1))(Literal(100).as("a")) + ) + + val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any)))) + // Copy the plan to reset its `analyzed` flag, so that analyzer rules will re-apply. + val copiedPlan = plan.transform { + case _: LocalRelation => testRelationWithData + } + comparePlans( + copiedPlan.analyze, // repeated analysis + testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index d54237fcc1407..01b1a627e2871 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -68,4 +68,18 @@ object TestRelations { val mapRelation = LocalRelation( AttributeReference("map", MapType(IntegerType, IntegerType))()) + + val streamingRelation = LocalRelation( + Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("ts", TimestampType)() + ), + isStreaming = true) + + val batchRelationWithTs = LocalRelation( + Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("ts", TimestampType)() + ), + isStreaming = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe89..21a049e914182 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -304,6 +304,36 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan + test("SPARK-49352: Avoid redundant array transform for identical expression") { + def assertArrayField(fromType: ArrayType, toType: ArrayType, hasTransform: Boolean): Unit = { + val table = TestRelation(Seq($"a".int, $"arr".array(toType))) + val query = TestRelation(Seq($"arr".array(fromType), $"a".int)) + + val writePlan = byName(table, query).analyze + + assertResolved(writePlan) + checkAnalysis(writePlan, writePlan) + + val transform = writePlan.children.head.expressions.exists { e => + e.find { + case _: ArrayTransform => true 
+ case _ => false + }.isDefined + } + if (hasTransform) { + assert(transform) + } else { + assert(!transform) + } + } + + assertArrayField(ArrayType(LongType), ArrayType(LongType), hasTransform = false) + assertArrayField( + ArrayType(new StructType().add("x", "int").add("y", "int")), + ArrayType(new StructType().add("y", "int").add("x", "byte")), + hasTransform = true) + } + test("SPARK-33136: output resolved on complex types for V2 write commands") { def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = { val table = TestRelation(StructType(Seq(StructField("a", toType)))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index b668386bc472d..05c1c33520dac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1880,7 +1880,8 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { conf.setConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS, 1L) withConfAndEmptyCatalog(conf) { catalog => - val table = QualifiedTableName(catalog.getCurrentDatabase, "test") + val table = QualifiedTableName( + CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "test") // First, make sure the test table is not cached. assert(catalog.getCachedTable(table) === null) @@ -1899,13 +1900,14 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { test("SPARK-34197: refreshTable should not invalidate the relation cache for temporary views") { withBasicCatalog { catalog => createTempView(catalog, "tbl1", Range(1, 10, 1, 10), false) - val qualifiedName1 = QualifiedTableName("default", "tbl1") + val qualifiedName1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "tbl1") catalog.cacheTable(qualifiedName1, Range(1, 10, 1, 10)) catalog.refreshTable(TableIdentifier("tbl1")) assert(catalog.getCachedTable(qualifiedName1) != null) createGlobalTempView(catalog, "tbl2", Range(2, 10, 1, 10), false) - val qualifiedName2 = QualifiedTableName(catalog.globalTempViewManager.database, "tbl2") + val qualifiedName2 = + QualifiedTableName(SESSION_CATALOG_NAME, catalog.globalTempViewManager.database, "tbl2") catalog.cacheTable(qualifiedName2, Range(2, 10, 1, 10)) catalog.refreshTable(TableIdentifier("tbl2", Some(catalog.globalTempViewManager.database))) assert(catalog.getCachedTable(qualifiedName2) != null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala index fcb10c98243d9..d94b74cff3032 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala @@ -76,8 +76,6 @@ class CSVExprUtilsSuite extends SparkFunSuite { // tab in the middle of some other letters ("""ba\tr""", Some("ba\tr"), None), // null character, expressed in Unicode literal syntax - ("\u0000", Some("\u0000"), None), - // and specified directly ("\u0000", Some("\u0000"), None) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index acedf7998c2d5..fb91200557a65 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -263,4 +263,14 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { inferSchema = new CSVInferSchema(options) assert(inferSchema.inferField(DateType, "2012_12_12") == DateType) } + + test("SPARK-45433: inferring the schema when timestamps do not match specified timestampFormat" + + " with only one row") { + val options = new CSVOptions( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + columnPruning = false, + defaultTimeZoneId = "UTC") + val inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "2884-06-24T02:45:51.138") == StringType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9d2051b01d62e..724a91806c7e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -477,6 +477,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(Option.empty[Int], "empty option of int") encodeDecodeTest(Option("abc"), "option of string") encodeDecodeTest(Option.empty[String], "empty option of string") + encodeDecodeTest(Seq(Some(Seq(0))), "SPARK-45896: seq of option of seq") + encodeDecodeTest(Map(0 -> Some(Seq(0))), "SPARK-45896: map of option of seq") + encodeDecodeTest(Seq(Some(Timestamp.valueOf("2023-01-01 00:00:00"))), + "SPARK-45896: seq of option of timestamp") + encodeDecodeTest(Map(0 -> Some(Timestamp.valueOf("2023-01-01 00:00:00"))), + "SPARK-45896: map of option of timestamp") + encodeDecodeTest(Seq(Some(Date.valueOf("2023-01-01"))), + "SPARK-45896: seq of option of date") + encodeDecodeTest(Map(0 -> Some(Date.valueOf("2023-01-01"))), + "SPARK-45896: map of option of date") + encodeDecodeTest(Seq(Some(BigDecimal(200))), "SPARK-45896: seq of option of bigdecimal") + encodeDecodeTest(Map(0 -> Some(BigDecimal(200))), "SPARK-45896: map of option of bigdecimal") encodeDecodeTest(ScroogeLikeExample(1), "SPARK-40385 class with only a companion object constructor") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e21793ab506c4..7a80188d445de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.RoundingMode import java.sql.{Date, Timestamp} import java.time.{Duration, Period} import java.time.temporal.ChronoUnit @@ -225,6 +226,121 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-45786: Decimal multiply, divide, remainder, quot") { + // Some known cases + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)), + Literal(Decimal(BigDecimal("-0.4652"), 4, 4)) + ), + Decimal(BigDecimal("6568635674732509803675414794505.574763")) + ) + checkEvaluation( + Multiply( + 
Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)), + Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25)) + ), + Decimal(BigDecimal("1367249507675382200.164877854336665327")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)), + Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1)) + ), + Decimal(BigDecimal("-3.237520E-31")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)), + Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3)) + ), + Decimal(BigDecimal("7.21642358550E-25")) + ) + + // Random tests + val rand = scala.util.Random + def makeNum(p: Int, s: Int): String = { + val int1 = rand.nextLong() + val int2 = rand.nextLong().abs + val frac1 = rand.nextLong().abs + val frac2 = rand.nextLong().abs + s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s) + } + + (0 until 100).foreach { _ => + val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38 + val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1 + val p2 = rand.nextInt(38) + 1 + val s2 = rand.nextInt(p2 + 1) + + val n1 = makeNum(p1, s1) + val n2 = makeNum(p2, s2) + + val mulActual = Multiply( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) + + val divActual = Divide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val divExact = new java.math.BigDecimal(n1) + .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) + + val remActual = Remainder( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) + + val quotActual = IntegralDivide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val quotExact = + new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) + + Seq(true, false).foreach { allowPrecLoss => + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) { + val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2) + val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP)) + val mulExpected = + if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult + checkEvaluationOrException(mulActual, mulExpected) + + val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2) + val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP)) + val divExpected = + if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult + checkEvaluationOrException(divActual, divExpected) + + val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2) + val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP)) + val remExpected = + if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult + checkEvaluationOrException(remActual, remExpected) + + val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2) + val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP)) + val quotExpected = + if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult + checkEvaluationOrException(quotActual, + if (quotExpected == null) null else quotExpected.toLong) + } + } + + def checkEvaluationOrException(actual: BinaryArithmetic, expected: Any): Unit = + 
if (SQLConf.get.ansiEnabled && expected == null) { + checkExceptionInExpression[SparkArithmeticException](actual, + "NUMERIC_VALUE_OUT_OF_RANGE") + } else { + checkEvaluation(actual, expected) + } + } + } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { testFunc(_.toDouble) testFunc(Decimal(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala index 4cd5f3e861ac8..5bd1bc346c02f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -133,6 +133,47 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("BitCount") { + // null + val nullLongLiteral = Literal.create(null, LongType) + val nullIntLiteral = Literal.create(null, IntegerType) + val nullBooleanLiteral = Literal.create(null, BooleanType) + checkEvaluation(BitwiseCount(nullLongLiteral), null) + checkEvaluation(BitwiseCount(nullIntLiteral), null) + checkEvaluation(BitwiseCount(nullBooleanLiteral), null) + + // boolean + checkEvaluation(BitwiseCount(Literal(true)), 1) + checkEvaluation(BitwiseCount(Literal(false)), 0) + + // byte/tinyint + checkEvaluation(BitwiseCount(Literal(1.toByte)), 1) + checkEvaluation(BitwiseCount(Literal(2.toByte)), 1) + checkEvaluation(BitwiseCount(Literal(3.toByte)), 2) + + // short/smallint + checkEvaluation(BitwiseCount(Literal(1.toShort)), 1) + checkEvaluation(BitwiseCount(Literal(2.toShort)), 1) + checkEvaluation(BitwiseCount(Literal(3.toShort)), 2) + + // int + checkEvaluation(BitwiseCount(Literal(1)), 1) + checkEvaluation(BitwiseCount(Literal(2)), 1) + checkEvaluation(BitwiseCount(Literal(3)), 2) + + // long/bigint + checkEvaluation(BitwiseCount(Literal(1L)), 1) + checkEvaluation(BitwiseCount(Literal(2L)), 1) + checkEvaluation(BitwiseCount(Literal(3L)), 2) + + // negative num + checkEvaluation(BitwiseCount(Literal(-1L)), 64) + + // edge value + checkEvaluation(BitwiseCount(Literal(9223372036854775807L)), 63) + checkEvaluation(BitwiseCount(Literal(-9223372036854775808L)), 1) + } + test("BitGet") { val nullLongLiteral = Literal.create(null, LongType) val nullIntLiteral = Literal.create(null, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 0e22b0d2876d7..89175ea1970cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -338,4 +338,17 @@ class CanonicalizeSuite extends SparkFunSuite { SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) } + + test("toJSON works properly with MultiCommutativeOp") { + val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "1") + + val d = Decimal(1.2) + val literal1 = Literal.create(d, DecimalType(2, 1)) + val literal2 = Literal.create(d, DecimalType(2, 1)) + val literal3 = Literal.create(d, DecimalType(3, 2)) + val op = Add(literal1, Add(literal2, literal3)) + assert(op.canonicalized.toJSON.nonEmpty) + 
SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 0172fd9b3e4c7..4352d5bc9c6bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period} import java.time.temporal.ChronoUnit import java.util.{Calendar, Locale, TimeZone} -import scala.collection.parallel.immutable.ParVector - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND import org.apache.spark.sql.types.UpCastRule.numericPrecedence import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ThreadUtils /** * Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work @@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - new ParVector(ALL_TIMEZONES.toVector).foreach { zid => + ThreadUtils.parmap( + ALL_TIMEZONES, + prefix = "CastSuiteBase-cast-string-to-timestamp", + maxThreads = Runtime.getRuntime.availableProcessors + ) { zid => def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = { checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected) } @@ -1171,7 +1174,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval => val dataType = YearMonthIntervalType() val expectedMsg = s"Interval string does not match year-month format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + s"${IntervalUtils.supportedFormat(("year-month", dataType.startField, dataType.endField)) .map(format => s"`$format`").mkString(", ")} " + s"when cast to ${dataType.typeName}: $interval" checkExceptionInExpression[IllegalArgumentException]( @@ -1191,7 +1194,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH))) .foreach { case (interval, dataType) => val expectedMsg = s"Interval string does not match year-month format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + s"${IntervalUtils.supportedFormat(("year-month", dataType.startField, dataType.endField)) .map(format => s"`$format`").mkString(", ")} " + s"when cast to ${dataType.typeName}: $interval" checkExceptionInExpression[IllegalArgumentException]( @@ -1311,7 +1314,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ("1.23", DayTimeIntervalType(MINUTE))) .foreach { case (interval, dataType) => val expectedMsg = s"Interval string does not match day-time format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + s"${IntervalUtils.supportedFormat(("day-time", dataType.startField, dataType.endField)) .map(format => s"`$format`").mkString(", ")} " + s"when cast to ${dataType.typeName}: $interval, " + s"set 
${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + @@ -1335,7 +1338,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND))) .foreach { case (interval, dataType) => val expectedMsg = "Interval string does not match day-time format of " + - s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + s"${IntervalUtils.supportedFormat(("day-time", dataType.startField, dataType.endField)) .map(format => s"`$format`").mkString(", ")} " + s"when cast to ${dataType.typeName}: $interval, " + s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala index 1dbf03b1538a6..502ae399ec16a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala @@ -507,6 +507,8 @@ class CastWithAnsiOffSuite extends CastSuiteBase { checkEvaluation(cast(1.0 / 0.0, TimestampType), null) checkEvaluation(cast(Float.NaN, TimestampType), null) checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) + checkEvaluation(cast(Literal(Long.MaxValue), TimestampType), Long.MaxValue) + checkEvaluation(cast(Literal(Long.MinValue), TimestampType), Long.MinValue) } test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1787f6ac72dd4..99eece31a1efc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds, import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // test sequence boundaries checking - checkExceptionInExpression[IllegalArgumentException]( - new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), - EmptyRow, s"Too long sequence: 4294967296. 
Should be <= $MAX_ROUNDED_ARRAY_LENGTH") - checkExceptionInExpression[IllegalArgumentException]( new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") checkExceptionInExpression[IllegalArgumentException]( @@ -782,6 +778,44 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[IllegalArgumentException]( new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + // SPARK-43393: test Sequence overflow checking + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString, + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> (BigInt(Long.MaxValue) + 1).toString, + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> ((0 - BigInt(Long.MinValue)) + 1).toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + checkErrorInExpression[SparkRuntimeException]( + new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)), + errorClass = "_LEGACY_ERROR_TEMP_2161", + parameters = Map( + "count" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString, + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + // test sequence with one element (zero step or equal start and stop) checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 5be0cae4a22f1..84520846a8d48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -71,10 +71,15 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB new ArrayBasedMapData(keyArray, valueArray) } + protected def replace(expr: Expression): Expression = expr match { + case r: RuntimeReplaceable => replace(r.replacement) + case _ => expr.mapChildren(replace) + } + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new 
JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone - val expr = resolver.resolveTimeZones(expression) + val expr = resolver.resolveTimeZones(replace(expression)) assert(expr.resolved) serializer.deserialize(serializer.serialize(expr)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 006c4a7805688..53a76e2cb9ff1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -464,6 +464,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val b = $"b".binary.at(0) val bytes = Array[Byte](1, 2, 3, 4) + assert(!Base64(Literal(bytes)).nullable) + assert(Base64(Literal.create(null, BinaryType)).nullable) + assert(Base64(Literal(bytes).castNullable()).nullable) + assert(!UnBase64(Literal("AQIDBA==")).nullable) + assert(UnBase64(Literal.create(null, StringType)).nullable) + assert(UnBase64(Literal("AQIDBA==").castNullable()).nullable) + checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef")) checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef")) checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef")) @@ -506,6 +513,23 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")) :: Nil) } + test("SPARK-47307: base64 encoding without chunking") { + val longString = "a" * 58 + val encoded = "YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYQ==" + withSQLConf(SQLConf.CHUNK_BASE64_STRING_ENABLED.key -> "false") { + checkEvaluation(Base64(Literal(longString.getBytes)), encoded) + } + val chunkEncoded = + s"YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFh\r\nYQ==" + withSQLConf(SQLConf.CHUNK_BASE64_STRING_ENABLED.key -> "true") { + checkEvaluation(Base64(Literal(longString.getBytes)), chunkEncoded) + } + + // check if unbase64 works well for chunked and non-chunked encoded strings + checkEvaluation(StringDecode(UnBase64(Literal(encoded)), Literal("utf-8")), longString) + checkEvaluation(StringDecode(UnBase64(Literal(chunkEncoded)), Literal("utf-8")), longString) + } + test("initcap unit test") { checkEvaluation(InitCap(Literal.create(null, StringType)), null) checkEvaluation(InitCap(Literal("a b")), "A B") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index f369635a32671..e9faeba2411ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -494,6 +494,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel checkShortcut(Or(equal, Literal(true)), 1) checkShortcut(Not(And(equal, Literal(false))), 1) } + + test("Equivalent ternary expressions have different children") { + val add1 = Add(Add(Literal(1), Literal(2)), Literal(3)) + val add2 = Add(Add(Literal(3), Literal(1)), Literal(2)) + val conditions1 = (GreaterThan(add1, Literal(3)), 
Literal(1)) :: + (GreaterThan(add2, Literal(0)), Literal(2)) :: Nil + + val caseWhenExpr1 = CaseWhen(conditions1, Literal(0)) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + assert(equivalence1.getCommonSubexpressions.size == 1) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAggSuite.scala new file mode 100644 index 0000000000000..daf3ede0d0369 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAggSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.DoubleType + +class CentralMomentAggSuite extends TestWithAndWithoutCodegen { + val input = AttributeReference("input", DoubleType, nullable = true)() + + testBothCodegenAndInterpreted("SPARK-46189: pandas_kurtosis eval") { + val evaluator = DeclarativeAggregateEvaluator(PandasKurtosis(input), Seq(input)) + val buffer = evaluator.update( + InternalRow(1.0d), + InternalRow(2.0d), + InternalRow(3.0d), + InternalRow(7.0d), + InternalRow(9.0d), + InternalRow(8.0d)) + val result = evaluator.eval(buffer) + assert(result === InternalRow(-2.5772889417360285d)) + } + + testBothCodegenAndInterpreted("SPARK-46189: pandas_skew eval") { + val evaluator = DeclarativeAggregateEvaluator(PandasSkewness(input), Seq(input)) + val buffer = evaluator.update( + InternalRow(1.0d), + InternalRow(2.0d), + InternalRow(2.0d), + InternalRow(2.0d), + InternalRow(2.0d), + InternalRow(100.0d)) + val result = evaluator.eval(buffer) + assert(result === InternalRow(2.4489389171333733d)) + } + + testBothCodegenAndInterpreted("SPARK-46189: pandas_stddev eval") { + val evaluator = DeclarativeAggregateEvaluator(PandasStddev(input, 1), Seq(input)) + val buffer = evaluator.update( + InternalRow(1.0d), + InternalRow(2.0d), + InternalRow(3.0d), + InternalRow(7.0d), + InternalRow(9.0d), + InternalRow(8.0d)) + val result = evaluator.eval(buffer) + assert(result === InternalRow(3.40587727318528d)) + } + + testBothCodegenAndInterpreted("SPARK-46189: pandas_variance eval") { + val evaluator = DeclarativeAggregateEvaluator(PandasVariance(input, 1), Seq(input)) + val buffer = evaluator.update( + InternalRow(1.0d), + InternalRow(2.0d), + InternalRow(3.0d), + InternalRow(7.0d), + InternalRow(9.0d), + InternalRow(8.0d)) + val result = evaluator.eval(buffer) + assert(result === 
InternalRow(11.6d)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CovarianceAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CovarianceAggSuite.scala new file mode 100644 index 0000000000000..2df053184c2b4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CovarianceAggSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.DoubleType + +class CovarianceAggSuite extends TestWithAndWithoutCodegen { + val a = AttributeReference("a", DoubleType, nullable = true)() + val b = AttributeReference("b", DoubleType, nullable = true)() + + testBothCodegenAndInterpreted("SPARK-46189: pandas_covar eval") { + val evaluator = DeclarativeAggregateEvaluator(PandasCovar(a, b, 1), Seq(a, b)) + val buffer = evaluator.update( + InternalRow(1.0d, 1.0d), + InternalRow(2.0d, 2.0d), + InternalRow(3.0d, 3.0d), + InternalRow(7.0d, 7.0d), + InternalRow(9.0, 9.0), + InternalRow(8.0d, 6.0)) + val result = evaluator.eval(buffer) + assert(result === InternalRow(10.4d)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala index b0f55b3b5c443..ac80e1419a99d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, SafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, MutableProjection} /** * Evaluator for a [[DeclarativeAggregate]]. 
*/ case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { - lazy val initializer = SafeProjection.create(function.initialValues) + lazy val initializer = MutableProjection.create(function.initialValues) - lazy val updater = SafeProjection.create( + lazy val updater = MutableProjection.create( function.updateExpressions, function.aggBufferAttributes ++ input) - lazy val merger = SafeProjection.create( + lazy val merger = MutableProjection.create( function.mergeExpressions, function.aggBufferAttributes ++ function.inputAggBufferAttributes) - lazy val evaluator = SafeProjection.create( + lazy val evaluator = MutableProjection.create( function.evaluateExpression :: Nil, function.aggBufferAttributes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/TestWithAndWithoutCodegen.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/TestWithAndWithoutCodegen.scala new file mode 100644 index 0000000000000..b43b160146eb4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/TestWithAndWithoutCodegen.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf + +trait TestWithAndWithoutCodegen extends SparkFunSuite with SQLHelper { + def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index 8d9f90a1a87c5..b25b191f8b136 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -185,6 +185,11 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { testExpr("b1b2b3c1c2", "a/*[@class='bb']/text()", Seq("b1", "c1")) + checkEvaluation( + Coalesce(Seq( + GetArrayItem(XPathList(Literal(""), Literal("a")), Literal(0)), + Literal("nul"))), "nul") + testNullAndErrorBehavior(testExpr) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala index 8290b38e33934..81a4858dce82a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala @@ -112,4 +112,12 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { checkType(Map("inferTimestamp" -> "true"), json, TimestampType) checkType(Map("inferTimestamp" -> "false"), json, StringType) } + + test("SPARK-45433: inferring the schema when timestamps do not match specified timestampFormat" + + " with only one row") { + checkType( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "inferTimestamp" -> "true"), + """{"a": "2884-06-24T02:45:51.138"}""", + StringType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ee56d1fa9acd3..2ebb43d4fba3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -1190,7 +1190,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown on watermark attribute #1") { val interval = new CalendarInterval(2, 2, 2000L) - val relation = LocalRelation(attrA, $"b".timestamp, attrC) + val relation = LocalRelation(Seq(attrA, $"b".timestamp, attrC), Nil, isStreaming = true) // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
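The same edit repeats in the remaining watermark hunks below. What flips these relations into streaming mode is the constructor being used: the varargs LocalRelation(...) apply leaves isStreaming at its default of false, whereas the full argument list sets the flag explicitly, and event-time watermarks are only meaningful on a streaming child. A minimal sketch of the two forms, using the same attributes as the hunk above:

    // Batch relation: varargs apply, isStreaming defaults to false.
    val batchRel = LocalRelation(attrA, $"b".timestamp, attrC)
    // Streaming relation: pass output, data and the flag explicitly.
    val streamingRel = LocalRelation(Seq(attrA, $"b".timestamp, attrC), Nil, isStreaming = true)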
@@ -1205,7 +1205,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown for nondeterministic filter") { val interval = new CalendarInterval(2, 2, 2000L) - val relation = LocalRelation(attrA, attrB, $"c".timestamp) + val relation = LocalRelation(Seq(attrA, attrB, $"c".timestamp), Nil, isStreaming = true) // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. @@ -1221,7 +1221,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: full pushdown") { val interval = new CalendarInterval(2, 2, 2000L) - val relation = LocalRelation(attrA, attrB, $"c".timestamp) + val relation = LocalRelation(Seq(attrA, attrB, $"c".timestamp), Nil, isStreaming = true) // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. @@ -1236,7 +1236,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown on watermark attribute #2") { val interval = new CalendarInterval(2, 2, 2000L) - val relation = LocalRelation($"a".timestamp, attrB, attrC) + val relation = LocalRelation(Seq($"a".timestamp, attrB, attrC), Nil, isStreaming = true) val originalQuery = EventTimeWatermark($"a", interval, relation) .where($"a" === new java.sql.Timestamp(0) && $"b" === 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala index 8af0e02855b12..13e138414781f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala @@ -42,7 +42,8 @@ class MergeScalarSubqueriesSuite extends PlanTest { } private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { - GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + GetStructField(ScalarSubquery( + CTERelationRef(cteIndex, _resolved = true, output, isStreaming = false)), fieldIndex) .as("scalarsubquery()") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index cb6b9ac8d8bec..6ce394dbd68be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -861,6 +861,27 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { // The plan is expected to be unchanged. 
comparePlans(plan, RemoveNoopOperators.apply(optimized.get)) } + + test("SPARK-48428: Do not pushdown when attr is used in expression with mutliple references") { + val query = contact + .limit(5) + .select( + GetStructField(GetStructField(CreateStruct(Seq($"id", $"employer")), 1), 0), + $"employer.id") + .analyze + + val optimized = Optimize.execute(query) + + val expected = contact + .select($"id", $"employer") + .limit(5) + .select( + GetStructField(GetStructField(CreateStruct(Seq($"id", $"employer")), 1), 0), + $"employer.id") + .analyze + + comparePlans(optimized, expected) + } } object NestedColumnAliasingSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index c185de4c05d88..eed06da609f8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -307,4 +307,21 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { comparePlans(optimized, query.analyze) } } + + test("SPARK-49743: prune unnecessary columns from GetArrayStructFields does not change schema") { + val options = Map.empty[String, String] + val schema = ArrayType(StructType.fromDDL("a int, b int"), containsNull = true) + + val field = StructField("A", IntegerType) // Instead of "a", use "A" to test case sensitivity. + val query = testRelation2 + .select(GetArrayStructFields( + JsonToStructs(schema, options, $"json"), field, 0, 2, true).as("a")) + val optimized = Optimizer.execute(query.analyze) + + val prunedSchema = ArrayType(StructType.fromDDL("a int"), containsNull = true) + val expected = testRelation2 + .select(GetArrayStructFields( + JsonToStructs(prunedSchema, options, $"json"), field, 0, 1, true).as("a")).analyze + comparePlans(optimized, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index e8d2ca1ff75de..451236162343b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -21,18 +21,19 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Literal, UnspecifiedFrame} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal, UnspecifiedFrame} import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Expand, LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, MetadataBuilder} -class PropagateEmptyRelationSuite extends PlanTest { +class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends 
RuleExecutor[LogicalPlan] { val batches = - Batch("PropagateEmptyRelation", Once, + Batch("PropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest { object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { val batches = - Batch("OptimizeWithoutPropagateEmptyRelation", Once, + Batch("OptimizeWithoutPropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -216,8 +217,61 @@ class PropagateEmptyRelationSuite extends PlanTest { .where($"a" =!= 200) .orderBy($"a".asc) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation(output, isStreaming = true) + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = LocalRelation(output, isStreaming = true) + comparePlans(optimized, correctAnswer) + } + + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = relation + .where(false) + .where($"a" > 1) + .select($"a") + .where($"a" =!= 200) + .orderBy($"a".asc).analyze + comparePlans(optimized, correctAnswer) + } + } + + test("SPARK-47305 correctly tag isStreaming when propagating empty relation " + + "with the plan containing batch and streaming") { + val data = Seq(Row(1)) + + val outputForStream = Seq($"a".int) + val schemaForStream = DataTypeUtils.fromAttributes(outputForStream) + val converterForStream = CatalystTypeConverters.createToCatalystConverter(schemaForStream) + + val outputForBatch = Seq($"b".int) + val schemaForBatch = DataTypeUtils.fromAttributes(outputForBatch) + val converterForBatch = CatalystTypeConverters.createToCatalystConverter(schemaForBatch) + + val streamingRelation = LocalRelation( + outputForStream, + data.map(converterForStream(_).asInstanceOf[InternalRow]), + isStreaming = true) + val batchRelation = LocalRelation( + outputForBatch, + data.map(converterForBatch(_).asInstanceOf[InternalRow]), + isStreaming = false) + + val query = streamingRelation + .join(batchRelation.where(false).select($"b"), LeftOuter, + Some(EqualTo($"a", $"b"))) + + val analyzedQuery = query.analyze + + val optimized = Optimize.execute(analyzedQuery) + // This is to deal with analysis for join condition. We expect the analyzed plan to contain + // filter and projection in batch relation, and know they will go away after optimization. + // The point to check here is that the node is replaced with "empty" LocalRelation, but the + // flag `isStreaming` is properly propagated. 
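      // Editorial aside, not part of the patch: the rewrite below relies on LocalRelation.copy
      // keeping every other field, so replacing the data with Seq.empty yields an empty relation
      // that still carries the original isStreaming flag of the batch branch, while the streaming
      // side of the outer join is left untouched.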
+ val correctAnswer = analyzedQuery transform { + case Project(_, Filter(_, l: LocalRelation)) => l.copy(data = Seq.empty) + } comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index b81a57f4f8cd5..66ded338340f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -174,4 +174,38 @@ class PruneFiltersSuite extends PlanTest { testRelation.where(!$"a".attr.in(1, 3, 5) && $"a".attr === 7 && $"b".attr === 1) .where(Rand(10) > 0.1 && Rand(10) < 1.1).analyze) } + + test("Streaming relation is not lost under true filter") { + Seq("true", "false").foreach(x => withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> x) { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 > 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + }) + } + + test("Streaming relation is not lost under false filter") { + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + } + + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.where(10 < 5).select($"a").analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index cd19e5062ae1f..8a0a0466ca741 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.MetadataBuilder class RemoveRedundantAliasAndProjectSuite extends PlanTest { @@ -130,4 +131,51 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest { correlated = false) comparePlans(optimized, expected) } + + test("SPARK-46640: do not remove outer references from a subquery expression") { + val a = $"a".int + val a_alias = Alias(a, "a")() + val a_alias_attr = a_alias.toAttribute + val b = $"b".int + + // The original input query + // Filter exists [a#1 && (a#1 = b#2)] + // : +- LocalRelation , 
[b#2] + // +- Project [a#0 AS a#1] + // +- LocalRelation , [a#0] + val query = Filter( + Exists( + LocalRelation(b), + outerAttrs = Seq(a_alias_attr), + joinCond = Seq(EqualTo(a_alias_attr, b)) + ), + Project(Seq(a_alias), LocalRelation(a)) + ) + + // The alias would not be removed if excluding subquery references is enabled. + val expectedWhenExcluded = query + + // The alias would have been removed if excluding subquery references is disabled. + // Filter exists [a#0 && (a#0 = b#2)] + // : +- LocalRelation , [b#2] + // +- LocalRelation , [a#0] + val expectedWhenNotExcluded = Filter( + Exists( + LocalRelation(b), + outerAttrs = Seq(a), + joinCond = Seq(EqualTo(a, b)) + ), + LocalRelation(a) + ) + + withSQLConf(SQLConf.EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES.key -> "true") { + val optimized = Optimize.execute(query) + comparePlans(optimized, expectedWhenExcluded) + } + + withSQLConf(SQLConf.EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES.key -> "false") { + val optimized = Optimize.execute(query) + comparePlans(optimized, expectedWhenNotExcluded) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 5d81e96a8e583..cb9577e050d04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -265,4 +265,35 @@ class ReplaceOperatorSuite extends PlanTest { Join(basePlan, otherPlan, LeftAnti, Option(condition), JoinHint.NONE)).analyze comparePlans(result, correctAnswer) } + + test("SPARK-46763: ReplaceDeduplicateWithAggregate non-grouping keys with duplicate attributes") { + val a = $"a".int + val b = $"b".int + val first_a = Alias(new First(a).toAggregateExpression(), a.name)() + + val query = Project( + projectList = Seq(a, b), + Deduplicate( + keys = Seq(b), + child = Project( + projectList = Seq(a, a, b), + child = LocalRelation(Seq(a, b)) + ) + ) + ).analyze + + val result = Optimize.execute(query) + val correctAnswer = Project( + projectList = Seq(first_a.toAttribute, b), + Aggregate( + Seq(b), + Seq(first_a, first_a, b), + Project( + projectList = Seq(a, a, b), + child = LocalRelation(Seq(a, b)) + ) + ) + ).analyze + comparePlans(result, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index ac136dfb898ef..4d31999ded655 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Literal, Round} import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} @@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest { case _ => fail(s"Plan is not 
rewritten:\n$rewrite") } } + + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { + val relation = testRelation2 + .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d") + val input = relation + .groupBy($"a", $"gb")( + countDistinct($"b").as("agg1"), + countDistinct($"d").as("agg2"), + Round(sum($"c").as("sum1"), 6)).analyze + val rewriteFold = FoldablePropagation(input) + // without the fix, the below produces an unresolved plan + val rewrite = RewriteDistinctAggregates(rewriteFold) + if (!rewrite.resolved) { + fail(s"Plan is not as expected:\n$rewrite") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 31fd232181a4f..176c24d4e100f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -39,7 +39,16 @@ class DDLParserSuite extends AnalysisTest { } private def parseCompare(sql: String, expected: LogicalPlan): Unit = { - comparePlans(parsePlan(sql), expected, checkAnalysis = false) + // We don't care the write privileges in this suite. + val parsed = parsePlan(sql).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, expected, checkAnalysis = false) } private def internalException(sqlText: String): SparkThrowable = { @@ -2356,6 +2365,18 @@ class DDLParserSuite extends AnalysisTest { stop = 42)) } + test("SPARK-46610: throw exception when no value for a key in create table options") { + val createTableSql = "create table test_table using my_data_source options (password)" + checkError( + exception = parseException(createTableSql), + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "A value must be specified for the key: password."), + context = ExpectedContext( + fragment = createTableSql, + start = 0, + stop = 62)) + } + test("UNCACHE TABLE") { comparePlans( parsePlan("UNCACHE TABLE a.b.c"), @@ -2602,15 +2623,15 @@ class DDLParserSuite extends AnalysisTest { val timestampTypeSql = s"INSERT INTO t PARTITION(part = timestamp'$timestamp') VALUES('a')" val binaryTypeSql = s"INSERT INTO t PARTITION(part = X'$binaryHexStr') VALUES('a')" - comparePlans(parsePlan(dateTypeSql), insertPartitionPlan("2019-01-02")) + parseCompare(dateTypeSql, insertPartitionPlan("2019-01-02")) withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { - comparePlans(parsePlan(intervalTypeSql), insertPartitionPlan(interval)) + parseCompare(intervalTypeSql, insertPartitionPlan(interval)) } - comparePlans(parsePlan(ymIntervalTypeSql), insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH")) - comparePlans(parsePlan(dtIntervalTypeSql), + parseCompare(ymIntervalTypeSql, insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH")) + parseCompare(dtIntervalTypeSql, insertPartitionPlan("INTERVAL '1 02:03:04.128462' DAY TO SECOND")) - comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(timestamp)) - comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(binaryStr)) + parseCompare(timestampTypeSql, insertPartitionPlan(timestamp)) + parseCompare(binaryTypeSql, insertPartitionPlan(binaryStr)) } test("SPARK-38335: Implement parser support for DEFAULT values for columns in tables") { @@ -2705,12 
+2726,12 @@ class DDLParserSuite extends AnalysisTest { // In each of the following cases, the DEFAULT reference parses as an unresolved attribute // reference. We can handle these cases after the parsing stage, at later phases of analysis. - comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"), + parseCompare("VALUES (1, 2, DEFAULT) AS val", SubqueryAlias("val", UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2), UnresolvedAttribute("DEFAULT")))))) - comparePlans(parsePlan( - "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"), + parseCompare( + "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)", InsertIntoStatement( UnresolvedRelation(Seq("t")), Map("part" -> Some("2019-01-02")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 4a5d0a0ae29fa..acc5a6ebddd2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]]. @@ -38,7 +39,16 @@ class PlanParserSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) + // We don't care the write privileges in this suite. 
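    // Editorial aside, not part of the patch: parsed INSERT targets now carry required write
    // privileges on their UnresolvedRelation (the counterpart of requireWritePrivileges used in
    // the DataFrameWriter changes later in this patch), so the helper strips them with
    // clearWritePrivileges before comparing against expected plans built without any privileges.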
+ val parsed = parsePlan(sqlCommand).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, plan, checkAnalysis = false) } private def parseException(sqlText: String): SparkThrowable = { @@ -1032,57 +1042,56 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ HINT */ * FROM t"), + assertEqual( + "SELECT /*+ HINT */ * FROM t", UnresolvedHint("HINT", Seq.empty, table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ MAPJOIN(u) */ * FROM t", UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + assertEqual( + "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t", UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), + assertEqual( + "SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t", UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + assertEqual( + "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`", UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")), table("default.t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + assertEqual( + "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a", UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy($"a")($"a")).orderBy($"a".asc)) - comparePlans( - parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + assertEqual( + "SELECT /*+ COALESCE(10) */ * FROM t", UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star()))) - comparePlans( - parsePlan( - "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + assertEqual( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t", InsertIntoStatement(table("s"), Map.empty, Nil, UnresolvedHint("REPARTITION", Seq(Literal(100)), UnresolvedHint("COALESCE", Seq(Literal(500)), UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star())))) @@ -1093,49 +1102,48 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'+'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(c) */ * FROM t", 
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star())))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star()))))) - comparePlans( - parsePlan( - """ - |SELECT - |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ - |* FROM t - """.stripMargin), + assertEqual( + """ + |SELECT + |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ + |* FROM t + """.stripMargin, UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")), table("t").select(star())))))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) } @@ -1758,4 +1766,15 @@ class PlanParserSuite extends AnalysisTest { parsePlan("SELECT * FROM a LIMIT ?"), table("a").select(star()).limit(PosParameter(22))) } + + test("SPARK-45189: Creating UnresolvedRelation from TableIdentifier should include the" + + " catalog field") { + val tableId = TableIdentifier("t", Some("db"), Some("cat")) + val unresolvedRelation = UnresolvedRelation(tableId) + assert(unresolvedRelation.multipartIdentifier == Seq("cat", "db", "t")) + val unresolvedRelation2 = UnresolvedRelation(tableId, CaseInsensitiveStringMap.empty, true) + assert(unresolvedRelation2.multipartIdentifier == Seq("cat", "db", "t")) + assert(unresolvedRelation2.options == CaseInsensitiveStringMap.empty) + assert(unresolvedRelation2.isStreaming) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala index c680e08c1c832..3012ef6f1544d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala @@ -39,7 +39,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, 
"col", Seq("val"), @@ -59,7 +59,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), Some(Seq(Some("A"), None)), "col", Seq("val"), @@ -76,7 +76,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT ((val1, val2) FOR col in ((a, b), (c, d)))", Unpivot( None, - Some(Seq(Seq($"a", $"b").map(UnresolvedAlias(_)), Seq($"c", $"d").map(UnresolvedAlias(_)))), + Some(Seq(Seq($"a", $"b"), Seq($"c", $"d"))), None, "col", Seq("val1", "val2"), @@ -96,10 +96,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq( - Seq($"a", $"b").map(UnresolvedAlias(_)), - Seq($"c", $"d").map(UnresolvedAlias(_)) - )), + Some(Seq(Seq($"a", $"b"), Seq($"c", $"d"))), Some(Seq(Some("first"), None)), "col", Seq("val1", "val2"), @@ -132,7 +129,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -169,7 +166,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT EXCLUDE NULLS (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -184,7 +181,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT INCLUDE NULLS (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -199,7 +196,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -211,7 +208,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -224,7 +221,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -239,7 +236,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -251,7 +248,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -267,7 +264,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -282,7 +279,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -296,7 +293,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", 
Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -311,7 +308,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -326,13 +323,13 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 33e521eb65a57..d1276615c5faa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -176,6 +176,22 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = rangeStats, extraConfig) } +test("range with invalid long value") { + val numElements = BigInt(Long.MaxValue) - BigInt(Long.MinValue) + val range = Range(Long.MinValue, Long.MaxValue, 1, None) + val rangeAttrs = AttributeMap(range.output.map(attr => + (attr, ColumnStat( + distinctCount = Some(numElements), + nullCount = Some(0), + maxLen = Some(LongType.defaultSize), + avgLen = Some(LongType.defaultSize))))) + val rangeStats = Statistics( + sizeInBytes = numElements * 8, + rowCount = Some(numElements), + attributeStats = rangeAttrs) + checkStats(range, rangeStats, rangeStats) +} + test("windows") { val windows = plan.window(Seq(min(attribute).as("sum_attr")), Seq(attribute), Nil) val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala index c634c5b739b8f..3de331f90a6d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala @@ -55,6 +55,12 @@ class NumberConverterSuite extends SparkFunSuite { checkConv("-10", 11, 7, "45012021522523134134555") } + test("SPARK-44973: conv must allocate enough space for all digits plus negative sign") { + checkConv(s"${Long.MinValue}", 10, -2, BigInt(Long.MinValue).toString(2)) + checkConv((BigInt(Long.MaxValue) + 1).toString(16), 16, -2, BigInt(Long.MinValue).toString(2)) + checkConv(BigInt(Long.MinValue).toString(16), 16, -2, BigInt(Long.MinValue).toString(2)) + } + test("byte to binary") { checkToBinary(0.toByte) checkToBinary(1.toByte) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala index eb173bc7f8c87..8ff6c7b2ad705 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala @@ -297,6 +297,40 @@ class TimestampFormatterSuite extends DatetimeFormatterSuite { } } + test("SPARK-49065: rebasing in legacy formatters/parsers with non-default time zone") { + val defaultTimeZone = LA + withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> LegacyBehaviorPolicy.LEGACY.toString) { + outstandingZoneIds.foreach { zoneId => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimeZone.getId) { + withDefaultTimeZone(defaultTimeZone) { + withClue(s"zoneId = ${zoneId.getId}") { + val formatters = LegacyDateFormats.values.toSeq.map { legacyFormat => + TimestampFormatter( + TimestampFormatter.defaultPattern(), + zoneId, + TimestampFormatter.defaultLocale, + legacyFormat, + isParsing = false) + } :+ TimestampFormatter.getFractionFormatter(zoneId) + formatters.foreach { formatter => + assert(microsToInstant(formatter.parse("1000-01-01 01:02:03")) + .atZone(zoneId) + .toLocalDateTime === LocalDateTime.of(1000, 1, 1, 1, 2, 3)) + + assert(formatter.format( + LocalDateTime.of(1000, 1, 1, 1, 2, 3).atZone(zoneId).toInstant) === + "1000-01-01 01:02:03") + assert(formatter.format(instantToMicros( + LocalDateTime.of(1000, 1, 1, 1, 2, 3) + .atZone(zoneId).toInstant)) === "1000-01-01 01:02:03") + } + } + } + } + } + } + } + test("parsing hour with various patterns") { def createFormatter(pattern: String): TimestampFormatter = { // Use `SIMPLE_DATE_FORMAT`, so that the legacy parser also fails with invalid value range. @@ -502,9 +536,20 @@ class TimestampFormatterSuite extends DatetimeFormatterSuite { assert(fastFormatter.parseOptional("2023-12-31 23:59:59.9990").contains(1704067199999000L)) assert(fastFormatter.parseOptional("abc").isEmpty) + assert(fastFormatter.parseOptional("23012150952").isEmpty) assert(simpleFormatter.parseOptional("2023-12-31 23:59:59.9990").contains(1704067208990000L)) assert(simpleFormatter.parseOptional("abc").isEmpty) + assert(simpleFormatter.parseOptional("23012150952").isEmpty) + } + test("SPARK-45424: do not return optional parse results when only prefix match") { + val formatter = new Iso8601TimestampFormatter( + "yyyy-MM-dd HH:mm:ss", + locale = DateFormatter.defaultLocale, + legacyFormat = LegacyDateFormats.SIMPLE_DATE_FORMAT, + isParsing = true, zoneId = DateTimeTestUtils.LA) + assert(formatter.parseOptional("9999-12-31 23:59:59.999").isEmpty) + assert(formatter.parseWithoutTimeZoneOptional("9999-12-31 23:59:59.999", true).isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index a0a4d8bdee9f5..a309db341d8e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -619,9 +619,9 @@ class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage } /** - * Theoretically, [[InternalRow]] returned by [[HasPartitionKey#partitionKey()]] + * Theoretically, `InternalRow` returned by `HasPartitionKey#partitionKey()` * does not need to implement equal and hashcode methods. - * But [[GenericInternalRow]] implements equals and hashcode methods already. Here we override it + * But `GenericInternalRow` implements equals and hashcode methods already. Here we override it * to simulate that it has not been implemented to verify codes correctness. 
*/ case class PartitionInternalRow(keys: Array[Any]) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7313ee5c41340..62d33dbfc2d41 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index b6184baa2e0ed..5bfe22450f36b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -31,12 +31,11 @@ import org.apache.orc.TypeDescription; import org.apache.orc.mapred.OrcInputFormat; +import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns; import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.VectorizedRowBatchWrap; -import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ConstantColumnVector; -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.*; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -73,11 +72,14 @@ public class OrcColumnarBatchReader extends RecordReader { @VisibleForTesting public ColumnarBatch columnarBatch; + private final MemoryMode memoryMode; + // The wrapped ORC column vectors. private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; - public OrcColumnarBatchReader(int capacity) { + public OrcColumnarBatchReader(int capacity, MemoryMode memoryMode) { this.capacity = capacity; + this.memoryMode = memoryMode; } @@ -177,7 +179,12 @@ public void initBatch( int colId = requestedDataColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + final WritableColumnVector missingCol; + if (memoryMode == MemoryMode.OFF_HEAP) { + missingCol = new OffHeapColumnVector(capacity, dt); + } else { + missingCol = new OnHeapColumnVector(capacity, dt); + } // Check if the missing column has an associated default value in the schema metadata. // If so, fill the corresponding column vector with the value. Object defaultValue = ResolveDefaultColumns.existenceDefaultValues(requiredSchema)[i]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 15d58f0c7572a..8c4fe20853879 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -109,22 +109,32 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa // For unsigned int64, it stores as plain signed int64 in Parquet when dictionary // fallbacks. We read them as decimal values. 
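      // Editorial aside, not part of the patch: the timestamp branches below now dispatch on the
      // Spark type (TimestampType vs. TimestampNTZType) as well as the Parquet time unit, and the
      // NTZ branches skip rebasing because, as the added comments note, TIMESTAMP_NTZ has no
      // legacy files that would need it.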
return new UnsignedLongUpdater(); - } else if (isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MICROS)) { - validateTimestampType(sparkType); + } else if (sparkType == DataTypes.TimestampType && + isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MICROS)) { if ("CORRECTED".equals(datetimeRebaseMode)) { return new LongUpdater(); } else { boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); return new LongWithRebaseUpdater(failIfRebase, datetimeRebaseTz); } - } else if (isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MILLIS)) { - validateTimestampType(sparkType); + } else if (sparkType == DataTypes.TimestampType && + isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MILLIS)) { if ("CORRECTED".equals(datetimeRebaseMode)) { return new LongAsMicrosUpdater(); } else { final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); return new LongAsMicrosRebaseUpdater(failIfRebase, datetimeRebaseTz); } + } else if (sparkType == DataTypes.TimestampNTZType && + isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MICROS)) { + validateTimestampNTZType(); + // TIMESTAMP_NTZ is a new data type and has no legacy files that need to do rebase. + return new LongUpdater(); + } else if (sparkType == DataTypes.TimestampNTZType && + isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit.MILLIS)) { + validateTimestampNTZType(); + // TIMESTAMP_NTZ is a new data type and has no legacy files that need to do rebase. + return new LongAsMicrosUpdater(); } else if (sparkType instanceof DayTimeIntervalType) { return new LongUpdater(); } @@ -194,12 +204,11 @@ boolean isTimestampTypeMatched(LogicalTypeAnnotation.TimeUnit unit) { ((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).getUnit() == unit; } - void validateTimestampType(DataType sparkType) { + private void validateTimestampNTZType() { assert(logicalTypeAnnotation instanceof TimestampLogicalTypeAnnotation); - // Throw an exception if the Parquet type is TimestampLTZ and the Catalyst type is TimestampNTZ. + // Throw an exception if the Parquet type is TimestampLTZ as the Catalyst type is TimestampNTZ. // This is to avoid mistakes in reading the timestamp values. - if (((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).isAdjustedToUTC() && - sparkType == DataTypes.TimestampNTZType) { + if (((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).isAdjustedToUTC()) { convertErrorForTimestampNTZ("int64 time(" + logicalTypeAnnotation + ")"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 4f2b65f36120a..7dfea3d980c55 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -140,12 +140,15 @@ public void initialize( // in test case. 
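    // Editorial aside, not part of the patch: the change below routes accumulator access through
    // TaskMetrics.withExternalAccums instead of reading externalAccums directly, presumably so the
    // read happens under the accessor that guards the underlying buffer; the NumRowGroupsAcc
    // bookkeeping itself is unchanged.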
TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); - if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { - @SuppressWarnings("unchecked") - AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); - intAccum.add(fileReader.getRowGroups().size()); - } + taskContext.taskMetrics().withExternalAccums((accums) -> { + Option> accu = accums.lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(fileReader.getRowGroups().size()); + } + return null; + }); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 46f241d92e6bd..122f775c2b0e0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -215,7 +215,9 @@ public byte[] getBytes(int rowId, int count) { Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count); } else { for (int i = 0; i < count; i++) { - array[i] = getByte(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -276,7 +278,9 @@ public short[] getShorts(int rowId, int count) { Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); } else { for (int i = 0; i < count; i++) { - array[i] = getShort(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -342,7 +346,9 @@ public int[] getInts(int rowId, int count) { Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); } else { for (int i = 0; i < count; i++) { - array[i] = getInt(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -420,7 +426,9 @@ public long[] getLongs(int rowId, int count) { Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); } else { for (int i = 0; i < count; i++) { - array[i] = getLong(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -484,7 +492,9 @@ public float[] getFloats(int rowId, int count) { Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); } else { for (int i = 0; i < count; i++) { - array[i] = getFloat(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -550,7 +560,9 @@ public double[] getDoubles(int rowId, int count) { count * 8L); } else { for (int i = 0; i < count; i++) { - array[i] = getDouble(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i)); + } } } return array; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index b717323753e87..160441e7583ed 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -213,7 +213,9 @@ public byte[] getBytes(int rowId, int count) { System.arraycopy(byteData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getByte(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -273,7 +275,9 @@ public short[] getShorts(int rowId, int count) { System.arraycopy(shortData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getShort(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -334,7 +338,9 @@ public int[] getInts(int rowId, int count) { System.arraycopy(intData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getInt(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -406,7 +412,9 @@ public long[] getLongs(int rowId, int count) { System.arraycopy(longData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getLong(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -463,7 +471,9 @@ public float[] getFloats(int rowId, int count) { System.arraycopy(floatData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getFloat(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i)); + } } } return array; @@ -522,7 +532,9 @@ public double[] getDoubles(int rowId, int count) { System.arraycopy(doubleData, rowId, array, 0, count); } else { for (int i = 0; i < count; i++) { - array[i] = getDouble(rowId + i); + if (!isNullAt(rowId + i)) { + array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i)); + } } } return array; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4de6b944bc868..4c0c750246f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,14 +28,17 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager, CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog.TableWritePrivilege +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{FieldReference, 
IdentityTransform, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -448,7 +451,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val table = catalog.asTableCatalog.loadTable(ident) match { + val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match { case _: V1Table => return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) case t => @@ -479,7 +482,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(tableIdent: TableIdentifier): Unit = { runCommand(df.sparkSession) { InsertIntoStatement( - table = UnresolvedRelation(tableIdent), + table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges), partitionSpec = Map.empty[String, Option[String]], Nil, query = df.logicalPlan, @@ -488,6 +491,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } + private def getWritePrivileges: Seq[TableWritePrivilege] = mode match { + case SaveMode.Overwrite => Seq(INSERT, DELETE) + case _ => Seq(INSERT) + } + private def getBucketSpec: Option[BucketSpec] = { if (sortColumnNames.isDefined && numBuckets.isEmpty) { throw QueryCompilationErrors.sortByWithoutBucketingError() @@ -557,7 +565,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined + val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf( + SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined && + !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) + .isInstanceOf[CatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => @@ -578,7 +589,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable( catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = { - val tableOpt = try Option(catalog.loadTable(ident)) catch { + val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch { case _: NoSuchTableException => None } @@ -639,7 +650,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val catalog = df.sparkSession.sessionState.catalog val qualifiedIdent = catalog.qualifyIdentifier(tableIdent) val tableExists = catalog.tableExists(qualifiedIdent) - val tableName = qualifiedIdent.unquotedString (tableExists, mode) match { case (true, SaveMode.Ignore) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 7ca9c7ef71d67..09d884af05b18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -146,7 +147,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) */ @throws(classOf[NoSuchTableException]) def append(): Unit = { - val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap) + val append = AppendData.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), + logicalPlan, options.toMap) runCommand(append) } @@ -163,7 +166,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, condition.expr, options.toMap) runCommand(overwrite) } @@ -183,7 +187,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwritePartitions(): Unit = { val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName), logicalPlan, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, options.toMap) runCommand(dynamicOverwrite) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fd8421fa096cc..a65aff4b228ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -774,8 +774,7 @@ class Dataset[T] private[sql]( val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) require(!IntervalUtils.isNegative(parsedDelay), s"delay threshold ($delayThreshold) should not be negative.") - EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) } /** @@ -1230,7 +1229,9 @@ class Dataset[T] private[sql]( JoinHint.NONE)).analyzed.asInstanceOf[Join] implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + ExpressionEncoder + .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true) + .asInstanceOf[Encoder[(T, U)]] withTypedPlan(JoinWith.typedJoinWith( joined, @@ -2189,7 +2190,7 @@ class Dataset[T] private[sql]( */ @varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan) + CollectMetrics(name, (expr +: 
exprs).map(_.named), logicalPlan, id) } /** @@ -3013,19 +3014,8 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): DataFrame = { - val allColumns = col +: cols - val expressions = (for (col <- allColumns) yield col match { - case Column(u: UnresolvedAttribute) => - queryExecution.analyzed.resolveQuoted( - u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) - case Column(expr: Expression) => expr - }) - val attrs = this.logicalPlan.output - val colsAfterDrop = attrs.filter { attr => - expressions.forall(expression => !attr.semanticEquals(expression)) - }.map(attr => Column(attr)) - select(colsAfterDrop : _*) + def drop(col: Column, cols: Column*): DataFrame = withPlan { + DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan) } /** @@ -4046,13 +4036,12 @@ class Dataset[T] private[sql]( new DataStreamWriter[T](this) } - /** * Returns the content of the Dataset as a Dataset of JSON strings. * @since 2.0.0 */ def toJSON: Dataset[String] = { - val rowSchema = this.schema + val rowSchema = exprEnc.schema val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone mapPartitions { iter => val writer = new CharArrayWriter() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index b7c86ab7de6b4..677dba0082575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} *

  <li>Customized Parser.</li>
 *   <li>(External) Catalog listeners.</li>
 *   <li>Columnar Rules.</li>
+ *   <li>Adaptive Query Post Planner Strategy Rules.</li>
 *   <li>Adaptive Query Stage Preparation Rules.</li>
 *   <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
 *   <li>Adaptive Query Stage Optimizer Rules.</li>
  • @@ -112,10 +113,13 @@ class SparkSessionExtensions { type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder) type ColumnarRuleBuilder = SparkSession => ColumnarRule + type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan] type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan] private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private[this] val queryPostPlannerStrategyRuleBuilders = + mutable.Buffer.empty[QueryPostPlannerStrategyBuilder] private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder] private[this] val queryStageOptimizerRuleBuilders = @@ -128,6 +132,14 @@ class SparkSessionExtensions { columnarRuleBuilders.map(_.apply(session)).toSeq } + /** + * Build the override rules for the query post planner strategy phase of adaptive query execution. + */ + private[sql] def buildQueryPostPlannerStrategyRules( + session: SparkSession): Seq[Rule[SparkPlan]] = { + queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq + } + /** * Build the override rules for the query stage preparation phase of adaptive query execution. */ @@ -156,6 +168,15 @@ class SparkSessionExtensions { columnarRuleBuilders += builder } + /** + * Inject a rule that applied between `plannerStrategy` and `queryStagePrepRules`, so + * it can get the whole plan before injecting exchanges. + * Note, these rules can only be applied within AQE. + */ + def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = { + queryPostPlannerStrategyRuleBuilders += builder + } + /** * Inject a rule that can override the query stage preparation phase of adaptive query * execution. 
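As a rough illustration of the injectQueryPostPlannerStrategyRule hook added in this file: it takes a SparkSession => Rule[SparkPlan] builder, and an extension registered through the usual spark.sql.extensions mechanism could wire in such a rule as sketched below. The MyExtensions class name and the no-op rule are hypothetical, shown only to make the injection point concrete; this is not code from the patch.

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Registered via spark.sql.extensions=com.example.MyExtensions (class name is hypothetical).
    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectQueryPostPlannerStrategyRule { _ =>
          new Rule[SparkPlan] {
            // Sees the complete physical plan produced by the planner, before AQE
            // splits it into query stages and inserts exchanges.
            override def apply(plan: SparkPlan): SparkPlan = plan
          }
        }
      }
    }

As the injector's own comment states, rules added this way only take effect under adaptive query execution.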
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index d8e19c994c59e..2a92dc59f3871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols} +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, CharVarcharUtils, ResolveDefaultColumns => DefaultCols} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.command._ @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.connector.V1Function -import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType} /** * Converts resolved v2 commands to v1 if the catalog is the session catalog. 
Since the v2 commands @@ -66,7 +66,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) throw QueryCompilationErrors.unsupportedTableOperationError(ident, "REPLACE COLUMNS") case a @ AlterColumn(ResolvedTable(catalog, ident, table: V1Table, _), _, _, _, _, _, _) - if isSessionCatalog(catalog) => + if supportsV1Command(catalog) => if (a.column.name.length > 1) { throw QueryCompilationErrors.unsupportedTableOperationError( catalog, ident, "ALTER COLUMN with qualified column") @@ -84,7 +84,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) val colName = a.column.name(0) val dataType = a.dataType.getOrElse { table.schema.findNestedField(Seq(colName), resolver = conf.resolver) - .map(_._2.dataType) + .map { + case (_, StructField(_, st: StringType, _, metadata)) => + CharVarcharUtils.getRawType(metadata).getOrElse(st) + case (_, field) => field.dataType + } .getOrElse { throw QueryCompilationErrors.alterColumnCannotFindColumnInV1TableError( quoteIfNeeded(colName), table) @@ -117,13 +121,13 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case UnsetViewProperties(ResolvedViewIdentifier(ident), keys, ifExists) => AlterTableUnsetPropertiesCommand(ident, keys, ifExists, isView = true) - case DescribeNamespace(DatabaseInSessionCatalog(db), extended, output) if conf.useV1Command => + case DescribeNamespace(ResolvedV1Database(db), extended, output) if conf.useV1Command => DescribeDatabaseCommand(db, extended, output) - case SetNamespaceProperties(DatabaseInSessionCatalog(db), properties) if conf.useV1Command => + case SetNamespaceProperties(ResolvedV1Database(db), properties) if conf.useV1Command => AlterDatabasePropertiesCommand(db, properties) - case SetNamespaceLocation(DatabaseInSessionCatalog(db), location) if conf.useV1Command => + case SetNamespaceLocation(ResolvedV1Database(db), location) if conf.useV1Command => if (StringUtils.isEmpty(location)) { throw QueryExecutionErrors.invalidEmptyLocationError(location) } @@ -218,7 +222,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case DropTable(ResolvedIdentifier(FakeSystemCatalog, ident), _, _) => DropTempViewCommand(ident) - case DropView(ResolvedV1Identifier(ident), ifExists) => + case DropView(ResolvedIdentifierInSessionCatalog(ident), ifExists) => DropTableCommand(ident, ifExists, isView = true, purge = false) case DropView(r @ ResolvedIdentifier(catalog, ident), _) => @@ -237,14 +241,14 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } CreateDatabaseCommand(name, c.ifNotExists, location, comment, newProperties) - case d @ DropNamespace(DatabaseInSessionCatalog(db), _, _) if conf.useV1Command => + case d @ DropNamespace(ResolvedV1Database(db), _, _) if conf.useV1Command => DropDatabaseCommand(db, d.ifExists, d.cascade) - case ShowTables(DatabaseInSessionCatalog(db), pattern, output) if conf.useV1Command => + case ShowTables(ResolvedV1Database(db), pattern, output) if conf.useV1Command => ShowTablesCommand(Some(db), pattern, output) case ShowTableExtended( - DatabaseInSessionCatalog(db), + ResolvedV1Database(db), pattern, partitionSpec @ (None | Some(UnresolvedPartitionSpec(_, _))), output) => @@ -265,16 +269,26 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) AnalyzePartitionCommand(ident, partitionSpec, noScan) } - case AnalyzeTables(DatabaseInSessionCatalog(db), noScan) => + case AnalyzeTables(ResolvedV1Database(db), noScan) => AnalyzeTablesCommand(Some(db), noScan) case AnalyzeColumn(ResolvedV1TableOrViewIdentifier(ident), 
columnNames, allColumns) => AnalyzeColumnCommand(ident, columnNames, allColumns) - case RepairTable(ResolvedV1TableIdentifier(ident), addPartitions, dropPartitions) => + // V2 catalog doesn't support REPAIR TABLE yet, we must use v1 command here. + case RepairTable( + ResolvedV1TableIdentifierInSessionCatalog(ident), + addPartitions, + dropPartitions) => RepairTableCommand(ident, addPartitions, dropPartitions) - case LoadData(ResolvedV1TableIdentifier(ident), path, isLocal, isOverwrite, partition) => + // V2 catalog doesn't support LOAD DATA yet, we must use v1 command here. + case LoadData( + ResolvedV1TableIdentifierInSessionCatalog(ident), + path, + isLocal, + isOverwrite, + partition) => LoadDataCommand( ident, path, @@ -293,7 +307,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) if conf.useV1Command => ShowCreateTableCommand(ident, output) case ShowCreateTable(ResolvedTable(catalog, _, table: V1Table, _), _, output) - if isSessionCatalog(catalog) && DDLUtils.isHiveTable(table.catalogTable) => + if supportsV1Command(catalog) && DDLUtils.isHiveTable(table.catalogTable) => ShowCreateTableCommand(table.catalogTable.identifier, output) case TruncateTable(ResolvedV1TableIdentifier(ident)) => @@ -322,7 +336,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } ShowColumnsCommand(db, v1TableName, output) - case RecoverPartitions(ResolvedV1TableIdentifier(ident)) => + // V2 catalog doesn't support RECOVER PARTITIONS yet, we must use v1 command here. + case RecoverPartitions(ResolvedV1TableIdentifierInSessionCatalog(ident)) => RepairTableCommand( ident, enableAddPartitions = true, @@ -350,8 +365,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) purge, retainData = false) + // V2 catalog doesn't support setting serde properties yet, we must use v1 command here. case SetTableSerDeProperties( - ResolvedV1TableIdentifier(ident), + ResolvedV1TableIdentifierInSessionCatalog(ident), serdeClassName, serdeProperties, partitionSpec) => @@ -361,13 +377,20 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) serdeProperties, partitionSpec) - case SetTableLocation(ResolvedV1TableIdentifier(ident), partitionSpec, location) => - AlterTableSetLocationCommand(ident, partitionSpec, location) + case SetTableLocation(ResolvedV1TableIdentifier(ident), None, location) => + AlterTableSetLocationCommand(ident, None, location) + + // V2 catalog doesn't support setting partition location yet, we must use v1 command here. 
+ case SetTableLocation( + ResolvedV1TableIdentifierInSessionCatalog(ident), + Some(partitionSpec), + location) => + AlterTableSetLocationCommand(ident, Some(partitionSpec), location) case AlterViewAs(ResolvedViewIdentifier(ident), originalText, query) => AlterViewAsCommand(ident, originalText, query) - case CreateView(ResolvedV1Identifier(ident), userSpecifiedColumns, comment, + case CreateView(ResolvedIdentifierInSessionCatalog(ident), userSpecifiedColumns, comment, properties, originalText, child, allowExisting, replace) => CreateViewCommand( name = ident, @@ -385,7 +408,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case ShowViews(ns: ResolvedNamespace, pattern, output) => ns match { - case DatabaseInSessionCatalog(db) => ShowViewsCommand(db, pattern, output) + case ResolvedDatabaseInSessionCatalog(db) => ShowViewsCommand(db, pattern, output) case _ => throw QueryCompilationErrors.missingCatalogAbilityError(ns.catalog, "views") } @@ -408,7 +431,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "functions") } - case ShowFunctions(DatabaseInSessionCatalog(db), userScope, systemScope, pattern, output) => + case ShowFunctions( + ResolvedDatabaseInSessionCatalog(db), userScope, systemScope, pattern, output) => ShowFunctionsCommand(db, pattern, userScope, systemScope, output) case DropFunction(ResolvedPersistentFunc(catalog, identifier, _), ifExists) => @@ -429,7 +453,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "REFRESH FUNCTION") } - case CreateFunction(ResolvedV1Identifier(ident), className, resources, ifExists, replace) => + case CreateFunction( + ResolvedIdentifierInSessionCatalog(ident), className, resources, ifExists, replace) => CreateFunctionCommand( FunctionIdentifier(ident.table, ident.database, ident.catalog), className, @@ -563,6 +588,14 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } object ResolvedV1TableIdentifier { + def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { + case ResolvedTable(catalog, _, t: V1Table, _) if supportsV1Command(catalog) => + Some(t.catalogTable.identifier) + case _ => None + } + } + + object ResolvedV1TableIdentifierInSessionCatalog { def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { case ResolvedTable(catalog, _, t: V1Table, _) if isSessionCatalog(catalog) => Some(t.catalogTable.identifier) @@ -579,6 +612,18 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } object ResolvedV1Identifier { + def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { + case ResolvedIdentifier(catalog, ident) if supportsV1Command(catalog) => + if (ident.namespace().length != 1) { + throw QueryCompilationErrors.requiresSinglePartNamespaceError(ident.namespace()) + } + Some(TableIdentifier(ident.name, Some(ident.namespace.head), Some(catalog.name))) + case _ => None + } + } + + // Use this object to help match commands that do not have a v2 implementation. 
+ object ResolvedIdentifierInSessionCatalog{ def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { case ResolvedIdentifier(catalog, ident) if isSessionCatalog(catalog) => if (ident.namespace().length != 1) { @@ -610,7 +655,21 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } } - private object DatabaseInSessionCatalog { + private object ResolvedV1Database { + def unapply(resolved: ResolvedNamespace): Option[String] = resolved match { + case ResolvedNamespace(catalog, _) if !supportsV1Command(catalog) => None + case ResolvedNamespace(_, Seq()) => + throw QueryCompilationErrors.databaseFromV1SessionCatalogNotSpecifiedError() + case ResolvedNamespace(_, Seq(dbName)) => Some(dbName) + case _ => + assert(resolved.namespace.length > 1) + throw QueryCompilationErrors.nestedDatabaseUnsupportedByV1SessionCatalogError( + resolved.namespace.map(quoteIfNeeded).mkString(".")) + } + } + + // Use this object to help match commands that do not have a v2 implementation. + private object ResolvedDatabaseInSessionCatalog { def unapply(resolved: ResolvedNamespace): Option[String] = resolved match { case ResolvedNamespace(catalog, _) if !isSessionCatalog(catalog) => None case ResolvedNamespace(_, Seq()) => @@ -625,11 +684,17 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) private object DatabaseNameInSessionCatalog { def unapply(resolved: ResolvedNamespace): Option[String] = resolved match { - case ResolvedNamespace(catalog, _) if !isSessionCatalog(catalog) => None + case ResolvedNamespace(catalog, _) if !supportsV1Command(catalog) => None case ResolvedNamespace(_, Seq(dbName)) => Some(dbName) case _ => assert(resolved.namespace.length > 1) throw QueryCompilationErrors.invalidDatabaseNameError(resolved.namespace.quoted) } } + + private def supportsV1Command(catalog: CatalogPlugin): Boolean = { + isSessionCatalog(catalog) && ( + SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty || + catalog.isInstanceOf[CatalogExtension]) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 947a5e9f383f9..c7bca751e56e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableExpression -import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType} +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType} /** * The builder to generate V2 expressions from catalyst expressions. 
@@ -96,45 +96,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Cast(v, dataType)) case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) - case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) - case Coalesce(children) => generateExpressionWithName("COALESCE", children) - case Greatest(children) => generateExpressionWithName("GREATEST", children) - case Least(children) => generateExpressionWithName("LEAST", children) - case Rand(child, hideSeed) => + case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate) + case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate) + case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate) + case _: Least => generateExpressionWithName("LEAST", expr, isPredicate) + case Rand(_, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpressionWithName("RAND", Seq(child)) + generateExpressionWithName("RAND", expr, isPredicate) } - case log: Logarithm => generateExpressionWithName("LOG", log.children) - case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) - case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) - case Log(child) => generateExpressionWithName("LN", Seq(child)) - case Exp(child) => generateExpressionWithName("EXP", Seq(child)) - case pow: Pow => generateExpressionWithName("POWER", pow.children) - case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) - case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) - case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) - case round: Round => generateExpressionWithName("ROUND", round.children) - case Sin(child) => generateExpressionWithName("SIN", Seq(child)) - case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) - case Cos(child) => generateExpressionWithName("COS", Seq(child)) - case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) - case Tan(child) => generateExpressionWithName("TAN", Seq(child)) - case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) - case Cot(child) => generateExpressionWithName("COT", Seq(child)) - case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) - case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) - case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) - case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) - case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) - case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) - case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) - case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) - case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) - case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) - case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) - case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) + case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate) + case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate) + case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate) + case _: Log => generateExpressionWithName("LN", expr, isPredicate) + case _: Exp => generateExpressionWithName("EXP", expr, 
isPredicate) + case _: Pow => generateExpressionWithName("POWER", expr, isPredicate) + case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate) + case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate) + case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate) + case _: Round => generateExpressionWithName("ROUND", expr, isPredicate) + case _: Sin => generateExpressionWithName("SIN", expr, isPredicate) + case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate) + case _: Cos => generateExpressionWithName("COS", expr, isPredicate) + case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate) + case _: Tan => generateExpressionWithName("TAN", expr, isPredicate) + case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate) + case _: Cot => generateExpressionWithName("COT", expr, isPredicate) + case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate) + case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate) + case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate) + case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate) + case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate) + case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate) + case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate) + case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate) + case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate) + case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate) + case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate) + case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -185,57 +185,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) - case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) - case CaseWhen(branches, elseValue) => + case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate) + case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate) + case caseWhen @ CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) - val values = branches.map(_._2).flatMap(generateExpression(_, true)) - if (conditions.length == branches.length && values.length == branches.length) { + val values = branches.map(_._2).flatMap(generateExpression(_)) + val elseExprOpt = elseValue.flatMap(generateExpression(_)) + if (conditions.length == branches.length && values.length == branches.length && + elseExprOpt.size == elseValue.size) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => Seq[V2Expression](c, v) } - if (elseValue.isDefined) { - elseValue.flatMap(generateExpression(_)).map { v => - val children = (branchExpressions :+ v).toArray[V2Expression] - // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] - new V2Predicate("CASE_WHEN", children) - } + val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression] + // The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)] + if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate("CASE_WHEN", 
children)) } else { - // The children looks like [condition1, value1, ..., conditionN, valueN] - Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression])) + Some(new GeneralScalarExpression("CASE_WHEN", children)) } } else { None } - case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) + case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate) case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - generateExpressionWithName("SUBSTRING", children) - case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) - case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) + generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate) + case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate) + case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate) case BitLength(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("BIT_LENGTH", Seq(child)) + generateExpressionWithName("BIT_LENGTH", expr, isPredicate) case Length(child) if child.dataType.isInstanceOf[StringType] => - generateExpressionWithName("CHAR_LENGTH", Seq(child)) - case concat: Concat => generateExpressionWithName("CONCAT", concat.children) - case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) - case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) - case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) - case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) + generateExpressionWithName("CHAR_LENGTH", expr, isPredicate) + case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate) + case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate) + case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate) + case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate) + case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - generateExpressionWithName("OVERLAY", children) - case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children) - case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) - case date: TruncDate => generateExpressionWithName("TRUNC", date.children) + generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate) + case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate) + case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate) + case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -268,12 +267,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Extract("WEEK", v)) case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) - case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children) - case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children) - case Crc32(child) => 
generateExpressionWithName("CRC32", Seq(child)) - case Md5(child) => generateExpressionWithName("MD5", Seq(child)) - case Sha1(child) => generateExpressionWithName("SHA1", Seq(child)) - case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children) + case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate) + case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate) + case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate) + case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate) + case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate) + case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate) // TODO supports other expressions case ApplyFunctionExpression(function, children) => val childrenExpressions = children.flatMap(generateExpression(_)) @@ -345,10 +344,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } private def generateExpressionWithName( - v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { + v2ExpressionName: String, + expr: Expression, + isPredicate: Boolean): Option[V2Expression] = { + generateExpressionWithNameByChildren( + v2ExpressionName, expr.children, expr.dataType, isPredicate) + } + + private def generateExpressionWithNameByChildren( + v2ExpressionName: String, + children: Seq[Expression], + dataType: DataType, + isPredicate: Boolean): Option[V2Expression] = { val childrenExpressions = children.flatMap(generateExpression(_)) if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + if (isPredicate && dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + Some(new GeneralScalarExpression( + v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index e1dcab80af307..428fe65501fb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -30,7 +30,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode with AliasAwareOutputExpression { final override def outputPartitioning: Partitioning = { val partitionings: Seq[Partitioning] = if (hasAlias) { - flattenPartitioning(child.outputPartitioning).flatMap { + flattenPartitioning(child.outputPartitioning).iterator.flatMap { case e: Expression => // We need unique partitionings but if the input partitioning is // `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after @@ -44,7 +44,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode .take(aliasCandidateLimit) .asInstanceOf[Stream[Partitioning]] case o => Seq(o) - } + }.take(aliasCandidateLimit).toSeq } else { // Filter valid partitiongs (only reference output attributes of the current plan node) val outputSet = AttributeSet(outputExpressions.map(_.toAttribute)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 064819275e004..9b79865149abd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -113,7 +113,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { planToCache: LogicalPlan, tableName: Option[String], storageLevel: StorageLevel): Unit = { - if (lookupCachedData(planToCache).nonEmpty) { + if (storageLevel == StorageLevel.NONE) { + // Do nothing for StorageLevel.NONE since it will not actually cache any data. + } else if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) @@ -400,8 +402,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { if (session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) { // Bucketed scan only has one time overhead but can have multi-times benefits in cache, // so we always do bucketed scan in a cached plan. - SparkSession.getOrCloneSessionWithConfigsOff( - session, SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil) + SparkSession.getOrCloneSessionWithConfigsOff(session, + SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS :: + SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil) } else { SparkSession.getOrCloneSessionWithConfigsOff(session, forceDisableConfigs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala index 5f38278d2dc67..45e3e41ab053d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala @@ -81,6 +81,8 @@ case class CommandResultExec( unsafeRows } + override def executeToIterator(): Iterator[InternalRow] = unsafeRows.iterator + override def executeTake(limit: Int): Array[InternalRow] = { val taken = unsafeRows.take(limit) longMetric("numOutputRows").add(taken.size) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3dcf0efaadd8f..3b49abcb1a866 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -150,6 +150,13 @@ case class LogicalRDD( } override lazy val constraints: ExpressionSet = originConstraints.getOrElse(ExpressionSet()) + // Subqueries can have non-deterministic results even when they only contain deterministic + // expressions (e.g. consider a LIMIT 1 subquery without an ORDER BY). Propagating predicates + // containing a subquery causes the subquery to be executed twice (as the result of the subquery + // in the checkpoint computation cannot be reused), which could result in incorrect results. + // Therefore we assume that all subqueries are non-deterministic, and we do not expose any + // constraints that contain a subquery. 
+ .filterNot(SubqueryExpression.hasSubquery) } object LogicalRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala index 3da3e646f36b0..421a963453f0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution -import java.util.Collections.newSetFromMap import java.util.IdentityHashMap -import java.util.Set import scala.collection.mutable.{ArrayBuffer, BitSet} @@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} object ExplainUtils extends AdaptiveSparkPlanHelper { + def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap + /** * Given a input physical plan, performs the following tasks. * 1. Computes the whole stage codegen id for current operator and records it in the @@ -75,25 +75,31 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { * Given a input physical plan, performs the following tasks. * 1. Generates the explain output for the input plan excluding the subquery plans. * 2. Generates the explain output for each subquery referenced in the plan. + * + * Note that, ideally this is a no-op as different explain actions operate on different plan, + * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared + * plan instance across multi-queries. Add lock for this method to avoid tag race condition. */ def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = { + val prevIdMap = localIdMap.get() try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration - val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap()) + // Initialize a reference-unique id map to store generated ids, which also avoid accidental + // overwrites and to allow intentional overwriting of IDs generated in previous AQE iteration + val idMap = new IdentityHashMap[QueryPlan[_], Int]() + localIdMap.set(idMap) // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out // Exchanges as part of SPARK-42753 val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] var currentOperatorID = 0 - currentOperatorID = generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, + currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges, true) val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] getSubqueries(plan, subqueries) currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, + (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges, true) } @@ -101,9 +107,9 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { val optimizedOutExchanges = ArrayBuffer.empty[Exchange] reusedExchanges.foreach{ reused => val child = reused.child - if (!operators.contains(child)) { + if (!idMap.containsKey(child)) { optimizedOutExchanges.append(child) - currentOperatorID = generateOperatorIDs(child, currentOperatorID, operators, + currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap, reusedExchanges, 
false) } } @@ -140,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { append("\n") } } finally { - removeTags(plan) + localIdMap.set(prevIdMap) } } @@ -155,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { * @param plan Input query plan to process * @param startOperatorID The start value of operation id. The subsequent operations will be * assigned higher value. - * @param visited A unique set of operators visited by generateOperatorIds. The set is scoped - * at the callsite function processPlan. It serves two purpose: Firstly, it is - * used to avoid accidentally overwriting existing IDs that were generated in - * the same processPlan call. Secondly, it is used to allow for intentional ID - * overwriting as part of SPARK-42753 where an Adaptively Optimized Out Exchange - * and its subtree may contain IDs that were generated in a previous AQE - * iteration's processPlan call which would result in incorrect IDs. + * @param idMap A reference-unique map store operators visited by generateOperatorIds and its + * id. This Map is scoped at the callsite function processPlan. It serves three + * purpose: + * Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to + * avoid accidentally overwriting existing IDs that were generated in the same + * processPlan call. Thirdly, it is used to allow for intentional ID overwriting as + * part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree + * may contain IDs that were generated in a previous AQE iteration's processPlan + * call which would result in incorrect IDs. * @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to * idenitfy adaptively optimized out exchanges in SPARK-42753. * @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. 
We set it @@ -173,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { private def generateOperatorIDs( plan: QueryPlan[_], startOperatorID: Int, - visited: Set[QueryPlan[_]], + idMap: java.util.Map[QueryPlan[_], Int], reusedExchanges: ArrayBuffer[ReusedExchangeExec], addReusedExchanges: Boolean): Int = { var currentOperationID = startOperatorID @@ -182,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { return currentOperationID } - def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { + def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => { plan match { case r: ReusedExchangeExec if addReusedExchanges => reusedExchanges.append(r) case _ => } - visited.add(plan) currentOperationID += 1 - plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) - } + currentOperationID + }) plan.foreachUp { case _: WholeStageCodegenExec => case _: InputAdapter => case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, visited, + currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap, reusedExchanges, addReusedExchanges) if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, visited, + currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap, reusedExchanges, addReusedExchanges) } setOpId(p) case p: QueryStageExec => - currentOperationID = generateOperatorIDs(p.plan, currentOperationID, visited, + currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap, reusedExchanges, addReusedExchanges) setOpId(p) case other: QueryPlan[_] => setOpId(other) currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => generateOperatorIDs(plan, curId, visited, reusedExchanges, + (curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges, addReusedExchanges) } } @@ -237,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { } def collectOperatorWithID(plan: QueryPlan[_]): Unit = { - plan.getTagValue(QueryPlan.OP_ID_TAG).foreach { id => + Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id => if (collectedOperators.add(id)) operators += plan } } @@ -330,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper { * `operationId` tag value. 
*/ def getOpId(plan: QueryPlan[_]): String = { - plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") - } - - def removeTags(plan: QueryPlan[_]): Unit = { - def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - p.unsetTagValue(QueryPlan.OP_ID_TAG) - p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG) - children.foreach(removeTags) - } - - plan foreach { - case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) - case p: QueryStageExec => remove(p, Seq(p.plan)) - case plan: QueryPlan[_] => remove(plan, plan.innerChildren) - } + Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f6dbf5fda1816..b99361437e0d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -78,6 +78,10 @@ case class GenerateExec( // boundGenerator.terminate() should be triggered after all of the rows in the partition val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithIndexInternal { (index, iter) => + boundGenerator.foreach { + case n: Nondeterministic => n.initialize(index) + case _ => + } val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) val rows = if (requiredChildOutput.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index daeac699c2791..b4cbb61352235 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture} import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.{ErrorMessageFormat, SparkContext, SparkThrowable, SparkThrowableHelper} +import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX} import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.SparkSession @@ -215,7 +215,8 @@ object SQLExecution { val activeSession = sparkSession val sc = sparkSession.sparkContext val localProps = Utils.cloneProperties(sc.getLocalProperties) - exec.submit(() => { + val artifactState = JobArtifactSet.getCurrentJobArtifactState.orNull + exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) { val originalSession = SparkSession.getActiveSession val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 70a35ea911538..6173703ef3cd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -89,7 +89,8 @@ class SparkOptimizer( InferWindowGroupLimit, LimitPushDown, LimitPushDownThroughWindow, - EliminateLimits) :+ + EliminateLimits, + ConstantFolding) :+ Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*) :+ Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index dfe3c67e18b1f..492d95cf3c482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -545,7 +545,7 @@ class SparkSqlAstBuilder extends AstBuilder { throw QueryParsingErrors.defineTempViewWithIfNotExistsError(ctx) } - withIdentClause(ctx.identifierReference(), ident => { + withIdentClause(ctx.identifierReference(), Seq(qPlan), (ident, otherPlans) => { val tableIdentifier = ident.asTableIdentifier if (tableIdentifier.database.isDefined) { // Temporary view names should NOT contain database prefix like "database.table" @@ -559,7 +559,7 @@ class SparkSqlAstBuilder extends AstBuilder { visitCommentSpecList(ctx.commentSpec()), properties, Option(source(ctx.query)), - qPlan, + otherPlans.head, ctx.EXISTS != null, ctx.REPLACE != null, viewType = viewType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 903565a6d591b..d851eacd5ab92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -935,7 +935,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE") case _: MergeIntoTable => throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE") - case logical.CollectMetrics(name, metrics, child) => + case logical.CollectMetrics(name, metrics, child, _) => execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index 7951a6f36b9bd..858130fae32b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -82,6 +82,13 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { case _ => false } + // A broadcast query stage can't be executed without the join operator. + // TODO: we can return the original query plan before broadcast. 
+ override protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, _: BroadcastQueryStageExec) => false + case _ => true + } + override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning( // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at // `PropagateEmptyRelationBase.commonApplyFunc` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala index 46ec91dcc0ab2..6b39ac70a62ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} @@ -75,7 +76,13 @@ case class AQEShuffleReadExec private( // partitions is changed. child.outputPartitioning match { case h: HashPartitioning => - CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length)) + val partitions = partitionSpecs.map { + case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end) + // Can not happend due to isCoalescedRead + case unexpected => + throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected") + } + CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions)) case r: RangePartitioning => CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length)) // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala index 8391fe44f5598..ee2cd8a4953bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala @@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan * query stage * @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure * all children query stages are materialized + * @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`, + * so it can get the whole plan before injecting exchanges. 
*/ class AdaptiveRulesHolder( val queryStagePrepRules: Seq[Rule[SparkPlan]], val runtimeOptimizerRules: Seq[Rule[LogicalPlan]], - val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) { + val queryStageOptimizerRules: Seq[Rule[SparkPlan]], + val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) { } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 36895b17aa847..77c9696e6e295 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} import org.apache.spark.sql.internal.SQLConf @@ -159,7 +159,13 @@ case class AdaptiveSparkPlanExec( ) private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = { - val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) => + val rules = if (isFinalStage && + !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) { + queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule]) + } else { + queryStageOptimizerRules + } + val optimized = rules.foldLeft(plan) { case (latestPlan, rule) => val applied = rule.apply(latestPlan) val result = rule match { case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) => @@ -187,9 +193,19 @@ case class AdaptiveSparkPlanExec( optimized } + private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = { + applyPhysicalRules( + plan, + context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules, + Some((planChangeLogger, "AQE Query Post Planner Strategy Rules")) + ) + } + @transient val initialPlan = context.session.withActive { applyPhysicalRules( - inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations"))) + applyQueryPostPlannerStrategyRules(inputPlan), + queryStagePreparationRules, + Some((planChangeLogger, "AQE Preparations"))) } @volatile private var currentPhysicalPlan = initialPlan @@ -238,7 +254,7 @@ case class AdaptiveSparkPlanExec( // and display SQL metrics correctly. // 2. If the `QueryExecution` does not match the current execution ID, it means the execution // ID belongs to another (parent) query, and we should not call update UI in this query. - // e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`. + // e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`. // // That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this // query execution need to do a plan update for the UI. 
@@ -294,6 +310,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => + stage.error.set(Some(e)) cleanUpAndThrowException(Seq(e), Some(stage.id)) } } @@ -309,6 +326,7 @@ case class AdaptiveSparkPlanExec( case StageSuccess(stage, res) => stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => + stage.error.set(Some(ex)) errors.append(ex) } @@ -541,9 +559,9 @@ case class AdaptiveSparkPlanExec( } } - case i: InMemoryTableScanExec => - // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we - // hit it the first time, we should always create a new query stage. + case i: InMemoryTableScanLike => + // There is no reuse for `InMemoryTableScanLike`, which is different from `Exchange`. + // If we hit it the first time, we should always create a new query stage. val newStage = newQueryStage(i) CreateStageResult( newPlan = newStage, @@ -551,6 +569,7 @@ case class AdaptiveSparkPlanExec( newStages = Seq(newStage)) case q: QueryStageExec => + assertStageNotFailed(q) CreateStageResult(newPlan = q, allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty) @@ -588,12 +607,12 @@ case class AdaptiveSparkPlanExec( } BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized) } - case i: InMemoryTableScanExec => + case i: InMemoryTableScanLike => // Apply `queryStageOptimizerRules` so that we can reuse subquery. - // No need to apply `postStageCreationRules` for `InMemoryTableScanExec` + // No need to apply `postStageCreationRules` for `InMemoryTableScanLike` // as it's a leaf node. val newPlan = optimizeQueryStage(i, isFinalStage = false) - if (!newPlan.isInstanceOf[InMemoryTableScanExec]) { + if (!newPlan.isInstanceOf[InMemoryTableScanLike]) { throw SparkException.internalError( "Custom AQE rules cannot transform table scan node to something else.") } @@ -700,7 +719,7 @@ case class AdaptiveSparkPlanExec( val optimized = optimizer.execute(logicalPlan) val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next() val newPlan = applyPhysicalRules( - sparkPlan, + applyQueryPostPlannerStrategyRules(sparkPlan), preprocessingRules ++ queryStagePreparationRules, Some((planChangeLogger, "AQE Replanning"))) @@ -763,6 +782,15 @@ case class AdaptiveSparkPlanExec( } } + private def assertStageNotFailed(stage: QueryStageExec): Unit = { + if (stage.hasFailed) { + throw stage.error.get().get match { + case fatal: SparkFatalException => fatal.throwable + case other => other + } + } + } + /** * Cancel all running stages with best effort and throw an Exception containing all stage * materialization errors and stage cancellation errors. 
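For context on the final-stage gating above, a small user-side sketch; the configuration key is assumed from the SQLConf constant name (ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS), since the key string itself is not visible in this diff:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Assumed key for SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS.
// When disabled, optimizeQueryStage drops AQEShuffleReadRule-based rewrites for the final
// stage, so the last shuffle keeps its original partitioning instead of being coalesced.
spark.conf.set("spark.sql.adaptive.applyFinalStageShuffleOptimizations", "false")

val df = spark.range(0, 1000000L).repartition(200)
df.groupBy("id").count().collect()

This appears to be what the CacheManager change relies on: when recompiling a cached plan it clones the session with this flag (and AUTO_BUCKETED_SCAN_ENABLED) turned off, so AQE does not coalesce away the partitioning of the plan being cached.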
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 34399001c726f..db4a6b7dcf2eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan, UnaryExecNode, UnionExec} import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -65,9 +66,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe } } - // Sub-plans under the Union operator can be coalesced independently, so we can divide them - // into independent "coalesce groups", and all shuffle stages within each group have to be - // coalesced together. + // Sub-plans under the Union/CartesianProduct/BroadcastHashJoin/BroadcastNestedLoopJoin + // operator can be coalesced independently, so we can divide them into independent + // "coalesce groups", and all shuffle stages within each group have to be coalesced together. val coalesceGroups = collectCoalesceGroups(plan) // Divide minimum task parallelism among coalesce groups according to their data sizes. @@ -136,8 +137,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe } /** - * Gather all coalesce-able groups such that the shuffle stages in each child of a Union operator - * are in their independent groups if: + * Gather all coalesce-able groups such that the shuffle stages in each child of a + * Union/CartesianProduct/BroadcastHashJoin/BroadcastNestedLoopJoin operator are in their + * independent groups if: * 1) all leaf nodes of this child are exchange stages; and * 2) all these shuffle stages support coalescing. */ @@ -146,13 +148,16 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe Seq(collectShuffleStageInfos(r)) case unary: UnaryExecNode => collectCoalesceGroups(unary.child) case union: UnionExec => union.children.flatMap(collectCoalesceGroups) - // If not all leaf nodes are exchange query stages, it's not safe to reduce the number of - // shuffle partitions, because we may break the assumption that all children of a spark plan - // have same number of output partitions. + case join: CartesianProductExec => join.children.flatMap(collectCoalesceGroups) // Note that, `BroadcastQueryStageExec` is a valid case: // If a join has been optimized from shuffled join to broadcast join, then the one side is // `BroadcastQueryStageExec` and other side is `ShuffleQueryStageExec`. It can coalesce the // shuffle side as we do not expect broadcast exchange has same partition number. 
+ case join: BroadcastHashJoinExec => join.children.flatMap(collectCoalesceGroups) + case join: BroadcastNestedLoopJoinExec => join.children.flatMap(collectCoalesceGroups) + // If not all leaf nodes are exchange query stages, it's not safe to reduce the number of + // shuffle partitions, because we may break the assumption that all children of a spark plan + // have same number of output partitions. case p if p.collectLeaves().forall(_.isInstanceOf[ExchangeQueryStageExec]) => val shuffleStages = collectShuffleStageInfos(p) // ShuffleExchanges introduced by repartition do not support partition number change. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index b941feb12fc05..31fe9b9ed368c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -88,6 +88,13 @@ abstract class QueryStageExec extends LeafExecNode { private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption final def isMaterialized: Boolean = resultOption.get().isDefined + @transient + @volatile + protected var _error = new AtomicReference[Option[Throwable]](None) + + def error: AtomicReference[Option[Throwable]] = _error + final def hasFailed: Boolean = _error.get().isDefined + override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning override def outputOrdering: Seq[SortOrder] = plan.outputOrdering @@ -195,6 +202,7 @@ case class ShuffleQueryStageExec( ReusedExchangeExec(newOutput, shuffle), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } @@ -247,6 +255,7 @@ case class BroadcastQueryStageExec( ReusedExchangeExec(newOutput, broadcast), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } @@ -261,7 +270,7 @@ case class BroadcastQueryStageExec( } /** - * A table cache query stage whose child is a [[InMemoryTableScanExec]]. + * A table cache query stage whose child is a [[InMemoryTableScanLike]]. * * @param id the query stage id. * @param plan the underlying plan. 
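The collectCoalesceGroups recursion extended above is easier to see on a toy plan tree. The following standalone sketch uses simplified node types rather than Spark's SparkPlan API, and it omits the real rule's extra checks (all leaves must be exchange query stages that support coalescing); it only illustrates how each child of a multi-child operator such as Union or a broadcast join becomes an independent coalesce group, while unary nodes pass their child's groups through.

// Toy plan nodes; names are illustrative, not Spark classes.
sealed trait Node { def children: Seq[Node] }
case class ShuffleStage(name: String) extends Node { val children: Seq[Node] = Nil }
case class Filter(child: Node) extends Node { val children: Seq[Node] = Seq(child) }
case class Union(children: Seq[Node]) extends Node
case class BroadcastJoin(left: Node, right: Node) extends Node {
  val children: Seq[Node] = Seq(left, right)
}

object CoalesceGroups {
  // Each returned group must be coalesced to one common partition count;
  // different groups can pick different counts independently.
  def collect(plan: Node): Seq[Seq[ShuffleStage]] = plan match {
    case s: ShuffleStage     => Seq(Seq(s))
    case Filter(child)       => collect(child)             // unary: stays in the same group
    case Union(children)     => children.flatMap(collect)  // children are independent
    case BroadcastJoin(l, r) => Seq(l, r).flatMap(collect) // sides need not align
  }

  def main(args: Array[String]): Unit = {
    val plan = Union(Seq(
      Filter(ShuffleStage("s1")),
      BroadcastJoin(ShuffleStage("s2"), ShuffleStage("s3"))))
    collect(plan).foreach(g => println(g.map(_.name).mkString("group: ", ", ", "")))
    // group: s1
    // group: s2
    // group: s3
  }
}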
@@ -271,7 +280,7 @@ case class TableCacheQueryStageExec( override val plan: SparkPlan) extends QueryStageExec { @transient val inMemoryTableScan = plan match { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i case _ => throw new IllegalStateException(s"wrong plan for table cache stage:\n ${plan.treeString}") } @@ -294,5 +303,5 @@ case class TableCacheQueryStageExec( override protected def doMaterialize(): Future[Any] = future - override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats() + override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala index dbed66683b017..9370b3d8d1d74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -128,8 +128,10 @@ object ShufflePartitionsUtil extends Logging { // There should be no unexpected partition specs and the start indices should be identical // across all different shuffles. - assert(partitionIndicesSeq.distinct.length == 1 && partitionIndicesSeq.head.forall(_ >= 0), - s"Invalid shuffle partition specs: $inputPartitionSpecs") + if (partitionIndicesSeq.distinct.length > 1 || partitionIndicesSeq.head.exists(_ < 0)) { + logWarning(s"Could not apply partition coalescing because of unexpected partition indices.") + return Seq.empty + } // The indices may look like [0, 1, 2, 2, 2, 3, 4, 4, 5], and the repeated `2` and `4` mean // skewed partitions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 86dd7984b5859..a843582e9c2c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -80,7 +80,7 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, - context: TaskContext) extends Iterator[Array[Byte]] { + context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable { protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) @@ -93,13 +93,11 @@ private[sql] object ArrowConverters extends Logging { protected val arrowWriter = ArrowWriter.create(root) Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() + close() }} override def hasNext: Boolean = rowIter.hasNext || { - root.close() - allocator.close() + close() false } @@ -124,6 +122,11 @@ private[sql] object ArrowConverters extends Logging { out.toByteArray } + + override def close(): Unit = { + root.close() + allocator.close() + } } private[sql] class ArrowBatchWithSchemaIterator( @@ -226,11 +229,19 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean): Array[Byte] = { - new ArrowBatchWithSchemaIterator( + val batches = new ArrowBatchWithSchemaIterator( Iterator.empty, schema, 0L, 0L, timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get) { override def hasNext: Boolean = true - }.next() + } + Utils.tryWithSafeFinally { + batches.next() + } { + // 
If taskContext is null, `batches.close()` should be called to avoid memory leak. + if (TaskContext.get() == null) { + batches.close() + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 45d006b58e879..f750a4503be16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -279,9 +279,11 @@ case class CachedRDDBuilder( cachedPlan.conf) } val cached = cb.mapPartitionsInternal { it => - TaskContext.get().addTaskCompletionListener[Unit](_ => { - materializedPartitions.add(1L) - }) + TaskContext.get().addTaskCompletionListener[Unit] { context => + if (!context.isFailed() && !context.isInterrupted()) { + materializedPartitions.add(1L) + } + } new Iterator[CachedBatch] { override def hasNext: Boolean = it.hasNext override def next(): CachedBatch = { @@ -390,20 +392,10 @@ case class InMemoryRelation( @volatile var statsOfPlanToCache: Statistics = null - - override lazy val innerChildren: Seq[SparkPlan] = { - // The cachedPlan needs to be cloned here because it does not get cloned when SparkPlan.clone is - // called. This is a problem because when the explain output is generated for - // a plan it traverses the innerChildren and modifies their TreeNode.tags. If the plan is not - // cloned, there is a thread safety issue in the case that two plans with a shared cache - // operator have explain called at the same time. The cachedPlan cannot be cloned because - // it contains stateful information so we only clone it for the purpose of generating the - // explain output. - Seq(cachedPlan.clone()) - } + override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExpressions(_, cachedPlan.output)), + copy(output = output.map(QueryPlan.normalizeExpressions(_, output)), cacheBuilder, outputOrdering) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 08244a4f84fea..5ff8bfd75f8a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec} @@ -28,11 +29,32 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.vectorized.ColumnarBatch +/** + * Common trait for all InMemoryTableScans implementations to facilitate pattern matching. + */ +trait InMemoryTableScanLike extends LeafExecNode { + + /** + * Returns whether the cache buffer is loaded + */ + def isMaterialized: Boolean + + /** + * Returns the actual cached RDD without filters and serialization of row/columnar. 
+ */ + def baseCacheRDD(): RDD[CachedBatch] + + /** + * Returns the runtime statistics after materialization. + */ + def runtimeStatistics: Statistics +} + case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends InMemoryTableScanLike { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -167,13 +189,18 @@ case class InMemoryTableScanExec( columnarInputRDD } - def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded + override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded /** * This method is only used by AQE which executes the actually cached RDD that without filter and * serialization of row/columnar. */ - def baseCacheRDD(): RDD[CachedBatch] = { + override def baseCacheRDD(): RDD[CachedBatch] = { relation.cacheBuilder.cachedColumnBuffers } + + /** + * Returns the runtime statistics after shuffle materialization. + */ + override def runtimeStatistics: Statistics = relation.computeStats() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index a8f7cdb260010..bb8fea71019fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -755,8 +755,10 @@ case class RepairTableCommand( val statusPar: Seq[FileStatus] = if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { // parallelize the list of partitions here, then we can have better parallelism later. + // scalastyle:off parvector val parArray = new ParVector(statuses.toVector) parArray.tasksupport = evalTaskSupport + // scalastyle:on parvector parArray.seq } else { statuses diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 3718794ea5909..b6159f92f9cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -167,7 +167,7 @@ case class CreateViewCommand( } } else { // Create the view if it doesn't exist. 
- catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) + catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = allowExisting) } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala index b5bf337a5a2e6..141767135a509 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CharType, Metadata, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols) }) } - paddingForStringComparison(newPlan) + paddingForStringComparison(newPlan, padCharCol = false) } else { - paddingForStringComparison(plan) + paddingForStringComparison( + plan, padCharCol = !conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE)) } } @@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = { + private def paddingForStringComparison(plan: LogicalPlan, padCharCol: Boolean): LogicalPlan = { plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) { case operator => operator.transformExpressionsUpWithPruning( _.containsAnyPattern(BINARY_COMPARISON, IN)) { @@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. 
case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) @@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { val literalCharLengths = literalChars.map(_.numChars()) val targetLen = (length +: literalCharLengths).max Some(i.copy( - value = addPadding(e, length, targetLen), + value = addPadding(e, length, targetLen, alwaysPad = padCharCol), list = list.zip(literalCharLengths).map { - case (lit, charLength) => addPadding(lit, charLength, targetLen) + case (lit, charLength) => + addPadding(lit, charLength, targetLen, alwaysPad = false) } ++ nulls.map(Literal.create(_, StringType)))) case _ => None }.getOrElse(i) @@ -134,7 +137,8 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { case (_, _: OuterReference) => Seq(right) case _ => Nil } - val newChildren = CharVarcharUtils.addPaddingInStringComparison(Seq(left, right)) + val newChildren = CharVarcharUtils.addPaddingInStringComparison( + Seq(left, right), padCharCol) if (outerRefs.nonEmpty) { b.withNewChildren(newChildren.map(_.transform { case a: Attribute if outerRefs.exists(_.semanticEquals(a)) => OuterReference(a) @@ -145,7 +149,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { case i @ In(e @ AttrOrOuterRef(attr), list) if list.forall(_.isInstanceOf[Attribute]) => val newChildren = CharVarcharUtils.addPaddingInStringComparison( - attr +: list.map(_.asInstanceOf[Attribute])) + attr +: list.map(_.asInstanceOf[Attribute]), padCharCol) if (e.isInstanceOf[OuterReference]) { i.copy( value = newChildren.head.transform { @@ -162,6 +166,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { private def padAttrLitCmp( expr: Expression, metadata: Metadata, + padCharCol: Boolean, lit: Expression): Option[Seq[Expression]] = { if (expr.dataType == StringType) { CharVarcharUtils.getRawType(metadata).flatMap { @@ -174,7 +179,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { if (length < stringLitLen) { Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) } else if (length > stringLitLen) { - Some(Seq(expr, StringRPad(lit, Literal(length)))) + val paddedExpr = if (padCharCol) { + StringRPad(expr, Literal(length)) + } else { + expr + } + Some(Seq(paddedExpr, StringRPad(lit, Literal(length)))) + } else if (padCharCol) { + Some(Seq(StringRPad(expr, Literal(length)), lit)) } else { None } @@ -186,7 +198,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { - if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + private def addPadding( + expr: Expression, + charLength: Int, + targetLength: Int, + alwaysPad: Boolean): Expression = { + if (targetLength > charLength) { + StringRPad(expr, Literal(targetLength)) + } else if (alwaysPad) { + StringRPad(expr, Literal(charLength)) + } else expr } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 94dd3bc0bd63e..2e24087d507bb 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -722,7 +722,7 @@ object DataSource extends Logging { val qualifiedPaths = pathStrings.map { pathString => val path = new Path(pathString) val fs = path.getFileSystem(hadoopConf) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) + fs.makeQualified(path) } // Split the paths into glob and non glob paths, because we don't need to do an existence check diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 94c2d2ffaca59..431480bb2edf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -27,28 +27,27 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} -import org.apache.spark.sql.connector.catalog.SupportsRead +import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, PushedDownOperators} import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -242,7 +241,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] { class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def readDataSourceTable( table: CatalogTable, extraOptions: CaseInsensitiveStringMap): LogicalPlan = { - val 
qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) + val qualifiedTableName = + QualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table) val catalog = sparkSession.sessionState.catalog val dsOptions = DataSourceUtils.generateDatasourceOptions(extraOptions, table) catalog.getCachedPlan(qualifiedTableName, () => { @@ -284,6 +284,13 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] _, _, _, _, _, _) => i.copy(table = DDLUtils.readHiveTable(tableMeta)) + case append @ AppendData( + DataSourceV2Relation( + V1Table(table: CatalogTable), _, _, _, _), _, _, _, _, _) if !append.isByName => + InsertIntoStatement(UnresolvedCatalogRelation(table), + table.partitionColumnNames.map(name => name -> None).toMap, + Seq.empty, append.query, false, append.isByName) + case UnresolvedCatalogRelation(tableMeta, options, false) if DDLUtils.isDatasourceTable(tableMeta) => readDataSourceTable(tableMeta, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index e4bf24ad88d1e..9fe42c6bcf2bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -210,9 +210,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects val requiredAttributes = AttributeSet(requiredExpressions) - val readDataColumns = dataColumns + val readDataColumns = dataColumnsWithoutPartitionCols .filter(requiredAttributes.contains) - .filterNot(partitionColumns.contains) // Metadata attributes are part of a column of type struct up to this point. Here we extract // this column from the schema and specify a matcher for that. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 1dffea4e1bc87..d5923a577daac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -63,7 +63,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _)) if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty => val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), + filters.filter { f => + f.deterministic && + !SubqueryExpression.hasSubquery(f) && + // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``. 
+ !f.exists(_.isInstanceOf[PythonUDF]) + }, logicalRelation.output) val (partitionKeyFilters, _) = DataSourceUtils .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala index 35d9b5d60348d..cf0e67ecc30fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala @@ -64,6 +64,7 @@ object SchemaMergeUtils extends Logging { val ignoreCorruptFiles = new FileSourceOptions(CaseInsensitiveMap(parameters)).ignoreCorruptFiles + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis // Issues a Spark job to read Parquet/ORC schema in parallel. val partiallyMergedSchemas = @@ -84,7 +85,7 @@ object SchemaMergeUtils extends Logging { var mergedSchema = schemas.head schemas.tail.foreach { schema => try { - mergedSchema = mergedSchema.merge(schema) + mergedSchema = mergedSchema.merge(schema, caseSensitive) } catch { case cause: SparkException => throw QueryExecutionErrors.failedMergingSchemaError(mergedSchema, schema, cause) } @@ -99,7 +100,7 @@ object SchemaMergeUtils extends Logging { var finalSchema = partiallyMergedSchemas.head partiallyMergedSchemas.tail.foreach { schema => try { - finalSchema = finalSchema.merge(schema) + finalSchema = finalSchema.merge(schema, caseSensitive) } catch { case cause: SparkException => throw QueryExecutionErrors.failedMergingSchemaError(finalSchema, schema, cause) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 069ad9562a7d5..0ff96f073f03b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -100,12 +100,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val columnPruning = sparkSession.sessionState.conf.csvColumnPruning val parsedOptions = new CSVOptions( options, - columnPruning, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val isColumnPruningEnabled = parsedOptions.isColumnPruningEnabled // Check a field requirement for corrupt records here to throw an exception in a driver side ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) @@ -125,7 +125,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { actualRequiredSchema, parsedOptions, actualFilters) - val schema = if (columnPruning) actualRequiredSchema else actualDataSchema + val schema = if (isColumnPruningEnabled) actualRequiredSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( schema, parsedOptions, source = s"CSV file: ${file.urlEncodedPath}", isStartOfFile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 268a65b81ff68..57651684070f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -239,6 +239,14 @@ class JDBCOptions( .get(JDBC_PREFER_TIMESTAMP_NTZ) .map(_.toBoolean) .getOrElse(SQLConf.get.timestampType == TimestampNTZType) + + override def hashCode: Int = this.parameters.hashCode() + + override def equals(other: Any): Boolean = other match { + case otherOption: JDBCOptions => + otherOption.parameters.equals(this.parameters) + case _ => false + } } class JdbcOptionsInWrite( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b7019c1dcbe53..7b5c4cfc9b6e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -269,6 +269,7 @@ object JdbcUtils extends Logging with SQLConfHelper { val fields = new Array[StructField](ncols) var i = 0 while (i < ncols) { + val metadata = new MetadataBuilder() val columnName = rsmd.getColumnLabel(i + 1) val dataType = rsmd.getColumnType(i + 1) val typeName = rsmd.getColumnTypeName(i + 1) @@ -289,8 +290,6 @@ object JdbcUtils extends Logging with SQLConfHelper { } else { rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls } - val metadata = new MetadataBuilder() - metadata.putLong("scale", fieldScale) dataType match { case java.sql.Types.TIME => @@ -302,7 +301,8 @@ object JdbcUtils extends Logging with SQLConfHelper { metadata.putBoolean("rowid", true) case _ => } - + metadata.putBoolean("isSigned", isSigned) + metadata.putLong("scale", fieldScale) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( getCatalystType(dataType, typeName, fieldSize, fieldScale, isSigned, isTimestampNTZ)) @@ -430,14 +430,16 @@ object JdbcUtils extends Logging with SQLConfHelper { case LongType if metadata.contains("binarylong") => (rs: ResultSet, row: InternalRow, pos: Int) => - val bytes = rs.getBytes(pos + 1) - var ans = 0L - var j = 0 - while (j < bytes.length) { - ans = 256 * ans + (255 & bytes(j)) - j = j + 1 - } - row.setLong(pos, ans) + val l = nullSafeConvert[Array[Byte]](rs.getBytes(pos + 1), bytes => { + var ans = 0L + var j = 0 + while (j < bytes.length) { + ans = 256 * ans + (255 & bytes(j)) + j = j + 1 + } + ans + }) + row.update(pos, l) case LongType => (rs: ResultSet, row: InternalRow, pos: Int) => @@ -898,7 +900,7 @@ object JdbcUtils extends Logging with SQLConfHelper { case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n) case _ => df } - repartitionedDF.rdd.foreachPartition { iterator => savePartition( + repartitionedDF.foreachPartition { iterator: Iterator[Row] => savePartition( table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b7e6f11f67d69..53d2b08431f85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -31,6 +31,7 @@ import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce._ import org.apache.spark.TaskContext +import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -152,6 +153,12 @@ class OrcFileFormat assert(supportBatch(sparkSession, resultSchema)) } + val memoryMode = if (sqlConf.offHeapColumnVectorEnabled) { + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis) val broadcastedConf = @@ -196,7 +203,7 @@ class OrcFileFormat val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) if (enableVectorizedReader) { - val batchReader = new OrcColumnarBatchReader(capacity) + val batchReader = new OrcColumnarBatchReader(capacity, memoryMode) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 5899b6621ad8e..0983841dc8c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets.UTF_8 import java.sql.{Date, Timestamp} @@ -612,7 +612,13 @@ class ParquetFilters( value == null || (nameToParquetField(name).fieldType match { case ParquetBooleanType => value.isInstanceOf[JBoolean] case ParquetIntegerType if value.isInstanceOf[Period] => true - case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value match { + // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type Int. + // We don't create a filter if the value would overflow. 
+ case _: JByte | _: JShort | _: Integer => true + case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue + case _ => false + } case ParquetLongType => value.isInstanceOf[JLong] || value.isInstanceOf[Duration] case ParquetFloatType => value.isInstanceOf[JFloat] case ParquetDoubleType => value.isInstanceOf[JDouble] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 023d2460959cd..95869b6fbb9d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -32,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) - extends FileSourceOptions(parameters) { + extends FileSourceOptions(parameters) with Logging { import ParquetOptions._ @@ -59,6 +60,9 @@ class ParquetOptions( throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } + if (codecName == "lz4raw") { + log.warn("Parquet compression codec 'lz4raw' is deprecated, please use 'lz4_raw'") + } shortParquetCompressionCodecNames(codecName).name() } @@ -96,7 +100,9 @@ object ParquetOptions extends DataSourceOptions { "lzo" -> CompressionCodecName.LZO, "brotli" -> CompressionCodecName.BROTLI, "lz4" -> CompressionCodecName.LZ4, + // Deprecated, to be removed at Spark 4.0.0, please use 'lz4_raw' instead. "lz4raw" -> CompressionCodecName.LZ4_RAW, + "lz4_raw" -> CompressionCodecName.LZ4_RAW, "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index e257be3d189aa..f534669d58c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -509,11 +509,10 @@ private[parquet] class ParquetRowConverter( // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the timestamp // values. 
private def canReadAsTimestampNTZ(parquetType: Type): Boolean = - schemaConverter.isTimestampNTZEnabled() && - parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 && - parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] && - !parquetType.getLogicalTypeAnnotation - .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC + parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 && + parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] && + !parquetType.getLogicalTypeAnnotation + .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 9c9e7ce729c1b..a78b96ae6fcc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -72,13 +72,6 @@ class ParquetToSparkSchemaConverter( inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean, nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean) - /** - * Returns true if TIMESTAMP_NTZ type is enabled in this ParquetToSparkSchemaConverter. - */ - def isTimestampNTZEnabled(): Boolean = { - inferTimestampNTZ - } - /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index eba3c71f871e3..2a3a5cdeb82b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -101,7 +101,7 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions).get.map(_._2) + groupPartitions(newPartitions).getOrElse(Seq.empty).map(_._2) case _ => // no validation is needed as the data source did not report any specific partitioning @@ -145,7 +145,7 @@ case class BatchScanExec( "is enabled") val groupedPartitions = groupPartitions(finalPartitions.map(_.head), - groupSplits = true).get + groupSplits = true).getOrElse(Seq.empty) // This means the input partitions are not grouped by partition values. 
We'll need to // check `groupByPartitionValues` and decide whether to group and replicate splits diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 0106a9c5aea0e..d46c5116e6151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.internal.Logging @@ -69,6 +70,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } } + private def hadoopConf = session.sessionState.newHadoopConf() + private def refreshCache(r: DataSourceV2Relation)(): Unit = { session.sharedState.cacheManager.recacheByPlan(session, r) } @@ -103,7 +106,19 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } private def qualifyLocInTableSpec(tableSpec: TableSpec): TableSpec = { - tableSpec.withNewLocation(tableSpec.location.map(makeQualifiedDBObjectPath(_))) + val newLoc = tableSpec.location.map { loc => + val locationUri = CatalogUtils.stringToURI(loc) + val qualified = if (locationUri.isAbsolute) { + locationUri + } else if (new Path(locationUri).isAbsolute) { + CatalogUtils.makeQualifiedPath(locationUri, hadoopConf) + } else { + // Leave it to the catalog implementation to qualify relative paths. + locationUri + } + CatalogUtils.URIToString(qualified) + } + tableSpec.withNewLocation(newLoc) } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala index 61ccda3fc9543..2683d8d547f00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala @@ -53,7 +53,7 @@ case class DescribeColumnExec( read.newScanBuilder(CaseInsensitiveStringMap.empty()).build() match { case s: SupportsReportStatistics => val stats = s.estimateStatistics() - Some(stats.columnStats().get(FieldReference.column(column.name))) + Option(stats.columnStats().get(FieldReference.column(column.name))) case _ => None } case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 447a36fe622c9..7e0bc25a9a1e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.{sources, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF, SubqueryExpression} import org.apache.spark.sql.connector.expressions.filter.Predicate import 
org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} @@ -73,7 +73,10 @@ abstract class FileScanBuilder( val (deterministicFilters, nonDeterminsticFilters) = filters.partition(_.deterministic) val (partitionFilters, dataFilters) = DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters) - this.partitionFilters = partitionFilters + this.partitionFilters = partitionFilters.filter { f => + // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``. + !SubqueryExpression.hasSubquery(f) && !f.exists(_.isInstanceOf[PythonUDF]) + } this.dataFilters = dataFilters val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] for (filterExpr <- dataFilters) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index 4b1a099d3bac9..f18424b4bcb86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -38,7 +38,7 @@ case class FileWriterFactory ( @transient private lazy val jobId = SparkHadoopWriterUtils.createJobID(jobTrackerID, 0) override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { - val taskAttemptContext = createTaskAttemptContext(partitionId) + val taskAttemptContext = createTaskAttemptContext(partitionId, realTaskId.toInt & Int.MaxValue) committer.setupTask(taskAttemptContext) if (description.partitionColumns.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) @@ -47,9 +47,11 @@ case class FileWriterFactory ( } } - private def createTaskAttemptContext(partitionId: Int): TaskAttemptContextImpl = { + private def createTaskAttemptContext( + partitionId: Int, + realTaskId: Int): TaskAttemptContextImpl = { val taskId = new TaskID(jobId, TaskType.MAP, partitionId) - val taskAttemptId = new TaskAttemptID(taskId, 0) + val taskAttemptId = new TaskAttemptID(taskId, realTaskId) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala index 6fa51ed63bd46..64f76f59c286f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCreateTableExec.scala @@ -120,7 +120,7 @@ case class ShowCreateTableExec( private def showTableLocation(table: Table, builder: StringBuilder): Unit = { val isManagedOption = Option(table.properties.get(TableCatalog.PROP_IS_MANAGED_LOCATION)) // Only generate LOCATION clause if it's not managed. 
- if (isManagedOption.forall(_.equalsIgnoreCase("false"))) { + if (isManagedOption.isEmpty || !isManagedOption.get.equalsIgnoreCase("true")) { Option(table.properties.get(TableCatalog.PROP_LOCATION)) .map("LOCATION '" + escapeSingleQuotedString(_) + "'\n") .foreach(builder.append) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index a7062a9a596c3..a022a01455a09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -121,7 +121,9 @@ class V2SessionCatalog(catalog: SessionCatalog) val storage = DataSource.buildStorageFormatFromOptions(toOptions(tableProperties.toMap)) .copy(locationUri = location.map(CatalogUtils.stringToURI)) val isExternal = properties.containsKey(TableCatalog.PROP_EXTERNAL) - val tableType = if (isExternal || location.isDefined) { + val isManagedLocation = Option(properties.get(TableCatalog.PROP_IS_MANAGED_LOCATION)) + .exists(_.equalsIgnoreCase("true")) + val tableType = if (isExternal || (location.isDefined && !isManagedLocation)) { CatalogTableType.EXTERNAL } else { CatalogTableType.MANAGED @@ -146,7 +148,7 @@ class V2SessionCatalog(catalog: SessionCatalog) throw QueryCompilationErrors.tableAlreadyExistsError(ident) } - loadTable(ident) + null // Return null to save the `loadTable` call for CREATE TABLE without AS SELECT. } private def toOptions(properties: Map[String, String]): Map[String, String] = { @@ -187,7 +189,7 @@ class V2SessionCatalog(catalog: SessionCatalog) throw QueryCompilationErrors.noSuchTableError(ident) } - loadTable(ident) + null // Return null to save the `loadTable` call for ALTER TABLE. 
} override def purgeTable(ident: Identifier): Boolean = { @@ -231,8 +233,6 @@ class V2SessionCatalog(catalog: SessionCatalog) throw QueryCompilationErrors.tableAlreadyExistsError(newIdent) } - // Load table to make sure the table exists - loadTable(oldIdent) catalog.renameTable(oldIdent.asTableIdentifier, newIdent.asTableIdentifier) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 4a9b85450a176..c99e2bba2e960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} @@ -81,9 +81,10 @@ case class CreateTableAsSelectExec( } throw QueryCompilationErrors.tableAlreadyExistsError(ident) } - val table = catalog.createTable( + val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), partitioning.toArray, properties.asJava) + ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) writeToTable(catalog, table, writeOptions, ident, query) } } @@ -115,9 +116,10 @@ case class AtomicCreateTableAsSelectExec( } throw QueryCompilationErrors.tableAlreadyExistsError(ident) } - val stagedTable = catalog.stageCreate( + val stagedTable = Option(catalog.stageCreate( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), partitioning.toArray, properties.asJava) + ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) writeToTable(catalog, stagedTable, writeOptions, ident, query) } } @@ -161,9 +163,10 @@ case class ReplaceTableAsSelectExec( } else if (!orCreate) { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } - val table = catalog.createTable( + val table = Option(catalog.createTable( ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), partitioning.toArray, properties.asJava) + ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) writeToTable(catalog, table, writeOptions, ident, query) } } @@ -213,7 +216,9 @@ case class AtomicReplaceTableAsSelectExec( } else { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } - writeToTable(catalog, staged, writeOptions, ident, query) + val table = Option(staged).getOrElse( + catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) + writeToTable(catalog, table, writeOptions, ident, query) } } diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index 37f6ae4aaa9fc..cef5a71ca9c60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -58,7 +58,7 @@ case class CSVPartitionReaderFactory( actualReadDataSchema, options, filters) - val schema = if (options.columnPruning) actualReadDataSchema else actualDataSchema + val schema = if (options.isColumnPruningEnabled) actualReadDataSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( schema, options, source = s"CSV file: ${file.urlEncodedPath}", isStartOfFile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 2b7bdae6b31b4..b23071e50cbed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -26,6 +26,7 @@ import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast +import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} @@ -57,7 +58,8 @@ case class OrcPartitionReaderFactory( partitionSchema: StructType, filters: Array[Filter], aggregation: Option[Aggregation], - options: OrcOptions) extends FilePartitionReaderFactory { + options: OrcOptions, + memoryMode: MemoryMode) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -146,7 +148,7 @@ case class OrcPartitionReaderFactory( val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) - val batchReader = new OrcColumnarBatchReader(capacity) + val batchReader = new OrcColumnarBatchReader(capacity, memoryMode) batchReader.initialize(fileSplit, taskAttemptContext) val requestedPartitionColIds = Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 072ab26774e52..ca37d22eeb1e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import 
org.apache.spark.sql.connector.expressions.aggregate.Aggregation @@ -64,11 +65,16 @@ case class OrcScan( override def createReaderFactory(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( new SerializableConfiguration(hadoopConf)) + val memoryMode = if (sparkSession.sessionState.conf.offHeapColumnVectorEnabled) { + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate, - new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf)) + new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf), memoryMode) } override def equals(obj: Any): Boolean = obj match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index 7360349284ec1..479e9065c0712 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -50,7 +50,8 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla // apply special dynamic filtering only for group-based row-level operations case GroupBasedRowLevelOperation(replaceData, _, Some(cond), DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) - if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral => + if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral + && scan.filterAttributes().nonEmpty => // use reference equality on scan to find required scan relations val newQuery = replaceData.query transformUp { @@ -115,6 +116,7 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla matchingRowsPlan: LogicalPlan, buildKeys: Seq[Attribute], pruningKeys: Seq[Attribute]): Expression = { + assert(buildKeys.nonEmpty && pruningKeys.nonEmpty) val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan) DynamicPruningExpression( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 42c880e7c6262..ee0ea11816f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -550,7 +550,7 @@ case class EnsureRequirements( private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -570,11 +570,10 @@ case class EnsureRequirements( } partitioning match { - case p: KeyGroupedPartitioning => check(p) + case p: KeyGroupedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => val specs = 
partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined)) - specs.head + specs.filter(_.isDefined).map(_.get).headOption case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 9f9f874314639..b82cee2c0fbe7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -73,7 +73,7 @@ case class BroadcastHashJoinExec( joinType match { case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { - case h: HashPartitioning => expandOutputPartitioning(h) + case h: HashPartitioningLike => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) case other => other } @@ -99,7 +99,7 @@ case class BroadcastHashJoinExec( private def expandOutputPartitioning( partitioning: PartitioningCollection): PartitioningCollection = { PartitioningCollection(partitioning.partitionings.flatMap { - case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) case other => Seq(other) }) @@ -111,11 +111,12 @@ case class BroadcastHashJoinExec( // the expanded partitioning will have the following expressions: // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. 
- private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { + private def expandOutputPartitioning( + partitioning: HashPartitioningLike): PartitioningCollection = { PartitioningCollection(partitioning.multiTransformDown { case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) => e +: streamedKeyToBuildKeyMapping(e.canonicalized) - }.asInstanceOf[Stream[HashPartitioning]] + }.asInstanceOf[Stream[HashPartitioningLike]] .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7c48baf99ef83..07f7915416c1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -138,9 +138,8 @@ trait HashJoin extends JoinCodegenSupport { UnsafeProjection.create(streamedBoundKeys) @transient protected[this] lazy val boundCondition = if (condition.isDefined) { - if (joinType == FullOuter && buildSide == BuildLeft) { - // Put join left side before right side. This is to be consistent with - // `ShuffledHashJoinExec.fullOuterJoin`. + if ((joinType == FullOuter || joinType == LeftOuter) && buildSide == BuildLeft) { + // Put join left side before right side. Predicate.create(condition.get, buildPlan.output ++ streamedPlan.output).eval _ } else { Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 877f6508d963f..77135d21a26ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -282,7 +282,7 @@ case class TakeOrderedAndProjectExec( projectList.map(_.toAttribute) } - override def executeCollect(): Array[InternalRow] = { + override def executeCollect(): Array[InternalRow] = executeQuery { val orderingSatisfies = SortOrder.orderingSatisfies(child.outputOrdering, sortOrder) val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val limited = if (orderingSatisfies) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 3326c5d4cb994..09406345ed770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -90,6 +90,13 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato AccumulableInfo(id, name, internOption(update), internOption(value), true, true, SQLMetrics.cachedSQLAccumIdentifier) } + + // We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly + // filter out the invalid -1 values. 
+ override def toInfoUpdate: AccumulableInfo = { + AccumulableInfo(id, name, internOption(Some(_value)), None, true, true, + SQLMetrics.cachedSQLAccumIdentifier) + } } object SQLMetrics { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala index ad3212871fc94..677e2fccb6b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -65,10 +65,10 @@ trait CheckpointFileManager { /** Open a file for reading, or throw exception if it does not exist. */ def open(path: Path): FSDataInputStream - /** List the files in a path that match a filter. */ + /** List the files/directories in a path that match a filter. */ def list(path: Path, filter: PathFilter): Array[FileStatus] - /** List all the files in a path. */ + /** List all the files/directories in a path. */ def list(path: Path): Array[FileStatus] = { list(path, (_: Path) => true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 04a1de02ea587..23855db9d7f5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -48,8 +48,8 @@ object FileStreamSink extends Logging { path match { case Seq(singlePath) => - val hdfsPath = new Path(singlePath) try { + val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) if (fs.isDirectory(hdfsPath)) { val metadataPath = getMetadataLogPath(fs, hdfsPath, sqlConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 2b0172bb9555c..9a811db679d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -325,6 +325,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** List the available batches on file system. */ protected def listBatches: Array[Long] = { val batchIds = fileManager.list(metadataPath, batchFilesFilter) + // Batches must be files + .filter(f => f.isFile) .map(f => pathToBatchId(f.getPath)) ++ // Iterate over keySet is not thread safe. We call `toArray` to make a copy in the lock to // elimiate the race condition. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 913805d1a074d..cea7ec432aad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -98,7 +98,9 @@ object OffsetSeqMetadata extends Logging { SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC, - STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION) + STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION, + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN + ) /** * Default values of relevant configurations that are used for backward compatibility. @@ -119,7 +121,8 @@ object OffsetSeqMetadata extends Logging { STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, STATE_STORE_COMPRESSION_CODEC.key -> "lz4", - STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" + STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false", + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true" ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 3ad1dc58cae79..15e27edcd0410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -631,13 +631,38 @@ case class StreamingSymmetricHashJoinExec( private val iteratorNotEmpty: Boolean = super.hasNext override def completion(): Unit = { - val isLeftSemiWithMatch = - joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty - // Add to state store only if both removal predicates do not match, - // and the row is not matched for left side of left semi join. - val shouldAddToState = - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) && - !isLeftSemiWithMatch + // The criteria for whether the input has to be added into the state store or not: + // - Left side: the input can skip being added to the state store if it's already matched + // and the join type is left semi. + // For other cases, the input should be added, including the case where it's going to be evicted + // in this batch. It hasn't yet been evaluated against inputs from the right side for this batch. + // Refer to the classdoc of StreamingSymmetricHashJoinExec about how stream-stream join + // works. + // - Right side: for this side, the evaluation against inputs from the left side for this batch + // is done at this point. That said, the input can skip being added to the state store + // if it is going to be evicted in this batch. However, the input should be added to the + // state store if it's a right outer join or full outer join, as unmatched output is + // handled during state eviction. 
+ val isLeftSemiWithMatch = joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty + val shouldAddToState = if (isLeftSemiWithMatch) { + false + } else if (joinSide == LeftSide) { + true + } else { + // joinSide == RightSide + + // if the input is not evicted in this batch (hence need to be persisted) + val isNotEvictingInThisBatch = + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + + isNotEvictingInThisBatch || + // if the input is producing "unmatched row" in this batch + ( + (joinType == RightOuter && !iteratorNotEmpty) || + (joinType == FullOuter && !iteratorNotEmpty) + ) + } + if (shouldAddToState) { joinStateManager.append(key, thisRow, matched = iteratorNotEmpty) updatedStateRowsCount += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala index 9a0b6a733d051..914d116e27fc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala @@ -31,7 +31,6 @@ trait HDFSBackedStateStoreMap { def remove(key: UnsafeRow): UnsafeRow def iterator(): Iterator[UnsafeRowPair] def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] - def clear(): Unit } object HDFSBackedStateStoreMap { @@ -79,8 +78,6 @@ class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap { override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = { throw new UnsupportedOperationException("Prefix scan is not supported!") } - - override def clear(): Unit = map.clear() } class PrefixScannableHDFSBackedStateStoreMap( @@ -169,9 +166,4 @@ class PrefixScannableHDFSBackedStateStoreMap( .iterator .map { key => unsafeRowPair.withRows(key, map.get(key)) } } - - override def clear(): Unit = { - map.clear() - prefixKeyToKeysMap.clear() - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index fbf4b357a35f9..85de3e7ff9425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -262,7 +262,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } override def close(): Unit = { - loadedMaps.values.asScala.foreach(_.clear()) + // Clearing the map resets the TreeMap.root to null, and therefore entries inside the + // `loadedMaps` will be de-referenced and GCed automatically when their reference + // counts become 0. 
+ synchronized { loadedMaps.clear() } } override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 2398b7780726a..6c0447e1a325a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.util.Locale +import java.util.concurrent.TimeUnit import javax.annotation.concurrent.GuardedBy import scala.collection.{mutable, Map} +import scala.collection.mutable.ListBuffer import scala.ref.WeakReference import scala.util.Try @@ -56,7 +58,11 @@ class RocksDB( hadoopConf: Configuration = new Configuration, loggingId: String = "") extends Logging { - case class RocksDBSnapshot(checkpointDir: File, version: Long, numKeys: Long) { + case class RocksDBSnapshot( + checkpointDir: File, + version: Long, + numKeys: Long, + capturedFileMappings: RocksDBFileMappings) { def close(): Unit = { silentDeleteRecursively(checkpointDir, s"Free up local checkpoint of snapshot $version") } @@ -64,6 +70,7 @@ class RocksDB( @volatile private var latestSnapshot: Option[RocksDBSnapshot] = None @volatile private var lastSnapshotVersion = 0L + private val oldSnapshots = new ListBuffer[RocksDBSnapshot] RocksDBLoader.loadLibrary() @@ -147,22 +154,30 @@ class RocksDB( try { if (loadedVersion != version) { closeDB() + // deep copy is needed to avoid race condition + // between maintenance and task threads + fileManager.copyFileMapping() val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version) val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion, workingDir) loadedVersion = latestSnapshotVersion + // reset last snapshot version + if (lastSnapshotVersion > latestSnapshotVersion) { + // discard any newer snapshots + lastSnapshotVersion = 0L + } openDB() numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) { - // we don't track the total number of rows - discard the number being track - -1L - } else if (metadata.numKeys < 0) { - // we track the total number of rows, but the snapshot doesn't have tracking number - // need to count keys now - countKeys() - } else { - metadata.numKeys - } + // we don't track the total number of rows - discard the number being track + -1L + } else if (metadata.numKeys < 0) { + // we track the total number of rows, but the snapshot doesn't have tracking number + // need to count keys now + countKeys() + } else { + metadata.numKeys + } if (loadedVersion != version) replayChangelog(version) // After changelog replay the numKeysOnWritingVersion will be updated to // the correct number of keys in the loaded version. @@ -191,6 +206,7 @@ class RocksDB( */ private def replayChangelog(endVersion: Long): Unit = { for (v <- loadedVersion + 1 to endVersion) { + logInfo(s"replaying changelog from version $loadedVersion -> $endVersion") var changelogReader: StateStoreChangelogReader = null try { changelogReader = fileManager.getChangelogReader(v) @@ -356,14 +372,19 @@ class RocksDB( // background operations. val cp = Checkpoint.create(db) cp.createCheckpoint(checkpointDir.toString) + // if changelog checkpointing is disabled, the snapshot is uploaded synchronously + // inside the uploadSnapshot() called below. 
+ // If changelog checkpointing is enabled, snapshot will be uploaded asynchronously + // during state store maintenance. synchronized { - // if changelog checkpointing is disabled, the snapshot is uploaded synchronously - // inside the uploadSnapshot() called below. - // If changelog checkpointing is enabled, snapshot will be uploaded asynchronously - // during state store maintenance. - latestSnapshot.foreach(_.close()) + if (latestSnapshot.isDefined) { + oldSnapshots += latestSnapshot.get + } latestSnapshot = Some( - RocksDBSnapshot(checkpointDir, newVersion, numKeysOnWritingVersion)) + RocksDBSnapshot(checkpointDir, + newVersion, + numKeysOnWritingVersion, + fileManager.captureFileMapReference())) lastSnapshotVersion = newVersion } } @@ -415,22 +436,34 @@ class RocksDB( } private def uploadSnapshot(): Unit = { + var oldSnapshotsImmutable: List[RocksDBSnapshot] = Nil val localCheckpoint = synchronized { val checkpoint = latestSnapshot latestSnapshot = None + + // Convert mutable list buffer to immutable to prevent + // race condition with commit where old snapshot is added + oldSnapshotsImmutable = oldSnapshots.toList + oldSnapshots.clear() + checkpoint } localCheckpoint match { - case Some(RocksDBSnapshot(localDir, version, numKeys)) => + case Some(RocksDBSnapshot(localDir, version, numKeys, capturedFileMappings)) => try { val uploadTime = timeTakenMs { - fileManager.saveCheckpointToDfs(localDir, version, numKeys) + fileManager.saveCheckpointToDfs(localDir, version, numKeys, capturedFileMappings) fileManagerMetrics = fileManager.latestSaveCheckpointMetrics } logInfo(s"$loggingId: Upload snapshot of version $version," + s" time taken: $uploadTime ms") } finally { localCheckpoint.foreach(_.close()) + + // Clean up old latestSnapshots + for (snapshot <- oldSnapshotsImmutable) { + snapshot.close() + } } case _ => } @@ -546,8 +579,11 @@ class RocksDB( private def acquire(): Unit = acquireLock.synchronized { val newAcquiredThreadInfo = AcquiredThreadInfo() - val waitStartTime = System.currentTimeMillis - def timeWaitedMs = System.currentTimeMillis - waitStartTime + val waitStartTime = System.nanoTime() + def timeWaitedMs = { + val elapsedNanos = System.nanoTime() - waitStartTime + TimeUnit.MILLISECONDS.convert(elapsedNanos, TimeUnit.NANOSECONDS) + } def isAcquiredByDifferentThread = acquiredThreadInfo != null && acquiredThreadInfo.threadRef.get.isDefined && newAcquiredThreadInfo.threadRef.get.get.getId != acquiredThreadInfo.threadRef.get.get.getId diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index 0891d7737135a..b4fe3e22e8882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -131,7 +131,6 @@ class RocksDBFileManager( import RocksDBImmutableFile._ - private val versionToRocksDBFiles = new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]] private lazy val fm = CheckpointFileManager.create(new Path(dfsRootDir), hadoopConf) private val fs = new Path(dfsRootDir).getFileSystem(hadoopConf) private val onlyZipFiles = new PathFilter { @@ -145,6 +144,30 @@ class RocksDBFileManager( private def codec = CompressionCodec.createCodec(sparkConf, codecName) + @volatile private var fileMappings = RocksDBFileMappings( + new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]], + new 
ConcurrentHashMap[String, RocksDBImmutableFile] + ) + + /** + * Make a deep copy of versionToRocksDBFiles and localFilesToDfsFiles to prevent the + * current task thread from overwriting the file mapping whenever the background maintenance + * thread attempts to upload a snapshot. + */ + def copyFileMapping(): Unit = { + val newVersionToRocksDBFiles = new ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]] + val newLocalFilesToDfsFiles = new ConcurrentHashMap[String, RocksDBImmutableFile] + + newVersionToRocksDBFiles.putAll(fileMappings.versionToRocksDBFiles) + newLocalFilesToDfsFiles.putAll(fileMappings.localFilesToDfsFiles) + + fileMappings = RocksDBFileMappings(newVersionToRocksDBFiles, newLocalFilesToDfsFiles) + } + + def captureFileMapReference(): RocksDBFileMappings = { + fileMappings + } + def getChangeLogWriter(version: Long): StateStoreChangelogWriter = { val rootDir = new Path(dfsRootDir) val changelogFile = dfsChangelogFile(version) @@ -176,10 +199,14 @@ class RocksDBFileManager( def latestSaveCheckpointMetrics: RocksDBFileManagerMetrics = saveCheckpointMetrics /** Save all the files in given local checkpoint directory as a committed version in DFS */ - def saveCheckpointToDfs(checkpointDir: File, version: Long, numKeys: Long): Unit = { + def saveCheckpointToDfs( + checkpointDir: File, + version: Long, + numKeys: Long, + capturedFileMappings: RocksDBFileMappings): Unit = { logFilesInDir(checkpointDir, s"Saving checkpoint files for version $version") val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir) - val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles) + val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles, capturedFileMappings) val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys) val metadataFile = localMetadataFile(checkpointDir) metadata.writeToFile(metadataFile) @@ -207,8 +234,13 @@ class RocksDBFileManager( */ def loadCheckpointFromDfs(version: Long, localDir: File): RocksDBCheckpointMetadata = { logInfo(s"Loading checkpoint files for version $version") + // The unique ids of SST files are checked when opening a rocksdb instance. The SST files + // in larger versions can't be reused even if they have the same size and name because + // they belong to another rocksdb instance. 
+ fileMappings.versionToRocksDBFiles.keySet().removeIf(_ >= version) val metadata = if (version == 0) { if (localDir.exists) Utils.deleteRecursively(localDir) + fileMappings.localFilesToDfsFiles.clear() localDir.mkdirs() RocksDBCheckpointMetadata(Seq.empty, 0) } else { @@ -221,7 +253,7 @@ class RocksDBFileManager( val metadata = RocksDBCheckpointMetadata.readFromFile(metadataFile) logInfo(s"Read metadata for version $version:\n${metadata.prettyJson}") loadImmutableFilesFromDfs(metadata.immutableFiles, localDir) - versionToRocksDBFiles.put(version, metadata.immutableFiles) + fileMappings.versionToRocksDBFiles.put(version, metadata.immutableFiles) metadataFile.delete() metadata } @@ -375,9 +407,9 @@ class RocksDBFileManager( // Resolve RocksDB files for all the versions and find the max version each file is used val fileToMaxUsedVersion = new mutable.HashMap[String, Long] sortedSnapshotVersions.foreach { version => - val files = Option(versionToRocksDBFiles.get(version)).getOrElse { + val files = Option(fileMappings.versionToRocksDBFiles.get(version)).getOrElse { val newResolvedFiles = getImmutableFilesFromVersionZip(version) - versionToRocksDBFiles.put(version, newResolvedFiles) + fileMappings.versionToRocksDBFiles.put(version, newResolvedFiles) newResolvedFiles } files.foreach(f => fileToMaxUsedVersion(f.dfsFileName) = @@ -422,7 +454,7 @@ class RocksDBFileManager( val versionFile = dfsBatchZipFile(version) try { fm.delete(versionFile) - versionToRocksDBFiles.remove(version) + fileMappings.versionToRocksDBFiles.remove(version) logDebug(s"Deleted version $version") } catch { case e: Exception => @@ -441,47 +473,55 @@ class RocksDBFileManager( /** Save immutable files to DFS directory */ private def saveImmutableFilesToDfs( version: Long, - localFiles: Seq[File]): Seq[RocksDBImmutableFile] = { + localFiles: Seq[File], + capturedFileMappings: RocksDBFileMappings): Seq[RocksDBImmutableFile] = { // Get the immutable files used in previous versions, as some of those uploaded files can be // reused for this version logInfo(s"Saving RocksDB files to DFS for $version") - val prevFilesToSizes = versionToRocksDBFiles.asScala.filterKeys(_ < version) - .values.flatten.map { f => - f.localFileName -> f - }.toMap var bytesCopied = 0L var filesCopied = 0L var filesReused = 0L val immutableFiles = localFiles.map { localFile => - prevFilesToSizes - .get(localFile.getName) - .filter(_.isSameFile(localFile)) - .map { reusable => - filesReused += 1 - reusable - }.getOrElse { - val localFileName = localFile.getName - val dfsFileName = newDFSFileName(localFileName) - val dfsFile = dfsFilePath(dfsFileName) - // Note: The implementation of copyFromLocalFile() closes the output stream when there is - // any exception while copying. So this may generate partial files on DFS. But that is - // okay because until the main [version].zip file is written, those partial files are - // not going to be used at all. Eventually these files should get cleared. 
- fs.copyFromLocalFile( - new Path(localFile.getAbsoluteFile.toURI), dfsFile) - val localFileSize = localFile.length() - logInfo(s"Copied $localFile to $dfsFile - $localFileSize bytes") - filesCopied += 1 - bytesCopied += localFileSize - - RocksDBImmutableFile(localFile.getName, dfsFileName, localFileSize) - } + val existingDfsFile = + capturedFileMappings.localFilesToDfsFiles.asScala.get(localFile.getName) + if (existingDfsFile.isDefined && existingDfsFile.get.sizeBytes == localFile.length()) { + val dfsFile = existingDfsFile.get + filesReused += 1 + logInfo(s"reusing file $dfsFile for $localFile") + RocksDBImmutableFile(localFile.getName, dfsFile.dfsFileName, dfsFile.sizeBytes) + } else { + val localFileName = localFile.getName + val dfsFileName = newDFSFileName(localFileName) + val dfsFile = dfsFilePath(dfsFileName) + // Note: The implementation of copyFromLocalFile() closes the output stream when there is + // any exception while copying. So this may generate partial files on DFS. But that is + // okay because until the main [version].zip file is written, those partial files are + // not going to be used at all. Eventually these files should get cleared. + fs.copyFromLocalFile( + new Path(localFile.getAbsoluteFile.toURI), dfsFile) + val localFileSize = localFile.length() + logInfo(s"Copied $localFile to $dfsFile - $localFileSize bytes") + filesCopied += 1 + bytesCopied += localFileSize + + val immutableDfsFile = RocksDBImmutableFile(localFile.getName, dfsFileName, localFileSize) + capturedFileMappings.localFilesToDfsFiles.put(localFileName, immutableDfsFile) + + immutableDfsFile + } } logInfo(s"Copied $filesCopied files ($bytesCopied bytes) from local to" + s" DFS for version $version. $filesReused files reused without copying.") - versionToRocksDBFiles.put(version, immutableFiles) + capturedFileMappings.versionToRocksDBFiles.put(version, immutableFiles) + + // Clean up locally deleted files from the localFilesToDfsFiles map + // Locally, SST files can be deleted due to RocksDB compaction. These files need + // to be removed from the localFilesToDfsFiles map to ensure that if an older version + // regenerates them and overwrites the version.zip, SST files from the conflicting + // version (previously committed) are not reused. + removeLocallyDeletedSSTFilesFromDfsMapping(localFiles) saveCheckpointMetrics = RocksDBFileManagerMetrics( bytesCopied = bytesCopied, @@ -499,14 +539,36 @@ class RocksDBFileManager( private def loadImmutableFilesFromDfs( immutableFiles: Seq[RocksDBImmutableFile], localDir: File): Unit = { val requiredFileNameToFileDetails = immutableFiles.map(f => f.localFileName -> f).toMap + + val localImmutableFiles = listRocksDBFiles(localDir)._1 + + // Clean up locally deleted files from the localFilesToDfsFiles map + // Locally, SST files can be deleted due to RocksDB compaction. These files need + // to be removed from the localFilesToDfsFiles map to ensure that if an older version + // regenerates them and overwrites the version.zip, SST files from the conflicting + // version (previously committed) are not reused. 
+ removeLocallyDeletedSSTFilesFromDfsMapping(localImmutableFiles) + // Delete unnecessary local immutable files - listRocksDBFiles(localDir)._1 + localImmutableFiles .foreach { existingFile => - val isSameFile = - requiredFileNameToFileDetails.get(existingFile.getName).exists(_.isSameFile(existingFile)) + val existingFileSize = existingFile.length() + val requiredFile = requiredFileNameToFileDetails.get(existingFile.getName) + val prevDfsFile = fileMappings.localFilesToDfsFiles.asScala.get(existingFile.getName) + val isSameFile = if (requiredFile.isDefined && prevDfsFile.isDefined) { + requiredFile.get.dfsFileName == prevDfsFile.get.dfsFileName && + existingFile.length() == requiredFile.get.sizeBytes + } else { + false + } + if (!isSameFile) { existingFile.delete() - logInfo(s"Deleted local file $existingFile") + fileMappings.localFilesToDfsFiles.remove(existingFile.getName) + logInfo(s"Deleted local file $existingFile with size $existingFileSize mapped" + + s" to previous dfsFile ${prevDfsFile.getOrElse("null")}") + } else { + logInfo(s"reusing $prevDfsFile present at $existingFile for $requiredFile") } } @@ -532,6 +594,7 @@ class RocksDBFileManager( } filesCopied += 1 bytesCopied += localFileSize + fileMappings.localFilesToDfsFiles.put(localFileName, file) logInfo(s"Copied $dfsFile to $localFile - $localFileSize bytes") } else { filesReused += 1 @@ -546,6 +609,19 @@ class RocksDBFileManager( filesReused = filesReused) } + private def removeLocallyDeletedSSTFilesFromDfsMapping(localFiles: Seq[File]): Unit = { + // clean up deleted SST files from the localFilesToDfsFiles map + val currentLocalFiles = localFiles.map(_.getName).toSet + val mappingsToClean = fileMappings.localFilesToDfsFiles.asScala + .keys + .filterNot(currentLocalFiles.contains) + + mappingsToClean.foreach { f => + logInfo(s"cleaning $f from the localFilesToDfsFiles map") + fileMappings.localFilesToDfsFiles.remove(f) + } + } + /** Get the SST files required for a version from the version zip file in DFS */ private def getImmutableFilesFromVersionZip(version: Long): Seq[RocksDBImmutableFile] = { Utils.deleteRecursively(localTempDir) @@ -649,6 +725,20 @@ class RocksDBFileManager( } } +/** + * Track file mappings in RocksDB across local and remote directories + * @param versionToRocksDBFiles Mapping of RocksDB files used across versions for maintenance + * @param localFilesToDfsFiles Mapping of the exact DFS file used to create a local SST file + * The reason localFilesToDfsFiles is a separate map is that versionToRocksDBFiles can contain + * multiple similar SST files for a particular local file (for example 1.sst can map to 1-UUID1.sst + * in v1 and 1-UUID2.sst in v2). We need to capture the exact file used to ensure Version ID + * compatibility across SST files and RocksDB manifest. + */ + +case class RocksDBFileMappings( + versionToRocksDBFiles: ConcurrentHashMap[Long, Seq[RocksDBImmutableFile]], + localFilesToDfsFiles: ConcurrentHashMap[String, RocksDBImmutableFile]) + /** * Metrics regarding RocksDB file sync between local and DFS. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 10f207c7ec1fe..a19eb00a7b5ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkConf, SparkEnv} @@ -202,7 +204,15 @@ private[sql] class RocksDBStateStoreProvider } override def doMaintenance(): Unit = { - rocksDB.doMaintenance() + try { + rocksDB.doMaintenance() + } catch { + // SPARK-46547 - Swallow non-fatal exception in maintenance task to avoid deadlock between + // maintenance thread and streaming aggregation operator + case NonFatal(ex) => + logWarning(s"Ignoring error while performing maintenance operations with exception=", + ex) + } } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b31f6151fce23..b597c9723f5cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -1037,10 +1037,14 @@ case class StreamingDeduplicateWithinWatermarkExec( protected val extraOptionOnStateStore: Map[String, String] = Map.empty - private val eventTimeCol: Attribute = WatermarkSupport.findEventTimeColumn(child.output, + // Below three variables are defined as lazy, as evaluating these variables does not work with + // canonicalized plan. Specifically, attributes in child won't have an event time column in + // the canonicalized plan. These variables are NOT referenced in canonicalized plan, hence + // defining these variables as lazy would avoid such error. + private lazy val eventTimeCol: Attribute = WatermarkSupport.findEventTimeColumn(child.output, allowMultipleEventTimeColumns = false).get - private val delayThresholdMs = eventTimeCol.metadata.getLong(EventTimeWatermark.delayKey) - private val eventTimeColOrdinal: Int = child.output.indexOf(eventTimeCol) + private lazy val delayThresholdMs = eventTimeCol.metadata.getLong(EventTimeWatermark.delayKey) + private lazy val eventTimeColOrdinal: Int = child.output.indexOf(eventTimeCol) protected def initializeReusedDupInfoRow(): Option[UnsafeRow] = { val timeoutToUnsafeRow = UnsafeProjection.create(schemaForValueRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 3fafc399dd828..17de4d42257b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -178,14 +178,14 @@ class SQLAppStatusListener( // work around a race in the DAGScheduler. The metrics info does not contain accumulator info // when reading event logs in the SHS, so we have to rely on the accumulator in that case. 
val accums = if (live && event.taskMetrics != null) { - event.taskMetrics.externalAccums.flatMap { a => + event.taskMetrics.withExternalAccums(_.flatMap { a => // This call may fail if the accumulator is gc'ed, so account for that. try { - Some(a.toInfo(Some(a.value), None)) + Some(a.toInfoUpdate) } catch { case _: IllegalAccessError => None } - } + }) } else { info.accumulables } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1504207d39cb1..668cece533353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -189,7 +189,8 @@ class SparkPlanGraphNode( } else { // SPARK-30684: when there is no metrics, add empty lines to increase the height of the node, // so that there won't be gaps between an edge and a small node. - s""" $id [labelType="html" label="
<br><b>$name</b><br><br>"];""" + val escapedName = StringEscapeUtils.escapeJava(name) + s""" $id [labelType="html" label="<br><b>$escapedName</b><br><br>
    "];""" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 2b7f702a7f20a..a849c3894f0d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -201,7 +201,11 @@ class FrameLessOffsetWindowFunctionFrame( override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { resetStates(rows) if (ignoreNulls) { - findNextRowWithNonNullInput() + if (Math.abs(offset) > rows.length) { + fillDefaultValue(EmptyRow) + } else { + findNextRowWithNonNullInput() + } } else { // drain the first few rows if offset is larger than zero while (inputIndex < offset) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 5543b409d1702..3a07dbf5480db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder( new AdaptiveRulesHolder( extensions.buildQueryStagePrepRules(session), extensions.buildRuntimeOptimizerRules(session), - extensions.buildQueryStageOptimizerRules(session)) + extensions.buildQueryStageOptimizerRules(session), + extensions.buildQueryPostPlannerStrategyRules(session)) } protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 796710a35672f..4e556ad846862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -656,12 +657,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } else { CatalogTableType.MANAGED } - val location = if (storage.locationUri.isDefined) { - val locationStr = storage.locationUri.get.toString - Some(locationStr) - } else { - None - } + + // The location in UnresolvedTableSpec should be the original user-provided path string. 
+ val location = CaseInsensitiveMap(options).get("path") val newOptions = OptionList(options.map { case (key, value) => (key, Literal(value).asInstanceOf[Expression]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index c246b50f4e156..336220922b5e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} -import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType} private[sql] object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = @@ -57,6 +57,20 @@ private[sql] object H2Dialect extends JdbcDialect { override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + sqlType match { + case Types.NUMERIC if size > 38 => + // H2 supports very large decimal precision like 100000. The max precision in Spark is only + // 38. Here we shrink both the precision and scale of H2 decimal to fit Spark, and still + // keep the ratio between them. + val scale = if (null != md) md.build().getLong("scale") else 0L + val selectedScale = (DecimalType.MAX_PRECISION * (scale.toDouble / size.toDouble)).toInt + Option(DecimalType(DecimalType.MAX_PRECISION, selectedScale)) + case _ => None + } + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", Types.CLOB)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) @@ -240,6 +254,7 @@ private[sql] object H2Dialect extends JdbcDialect { } class H2SQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( funcName: String, isDistinct: Boolean, inputs: Array[String]): String = if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 2f5e813dcb618..ae8d89f0f0469 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -198,7 +198,7 @@ abstract class JdbcDialect extends Serializable with Logging { * @return The SQL query to use for checking the table. 
*/ def getTableExistsQuery(table: String): String = { - s"SELECT * FROM $table WHERE 1=0" + s"SELECT 1 FROM $table WHERE 1=0" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 78ec3ac42d797..3022bca87a9f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.internal.SQLConf @@ -86,6 +87,20 @@ private object MsSqlServerDialect extends JdbcDialect { case "STDDEV_SAMP" => "STDEV" case _ => super.dialectFunctionName(funcName) } + + override def build(expr: Expression): String = { + // MsSqlServer does not support boolean comparison using standard comparison operators + // We shouldn't propagate these queries to MsSqlServer + expr match { + case e: Predicate => e.name() match { + case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" + if e.children().exists(_.isInstanceOf[Predicate]) => + super.visitUnexpectedExpr(expr) + case _ => super.build(expr) + } + case _ => super.build(expr) + } + } } override def compileExpression(expr: Expression): Option[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index a08c89318b660..5d9ff94838f15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference, NullOrdering, SortDirection} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} -import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder, StringType} +import org.apache.spark.sql.types._ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { @@ -65,6 +65,21 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { } } + override def visitStartsWith(l: String, r: String): String = { + val value = r.substring(1, r.length() - 1) + s"$l LIKE '${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'" + } + + override def visitEndsWith(l: String, r: String): String = { + val value = r.substring(1, r.length() - 1) + s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}' ESCAPE '\\\\'" + } + + override def visitContains(l: String, r: String): String = { + val value = r.substring(1, r.length() - 1) + s"$l LIKE '%${escapeSpecialCharsForLikePattern(value)}%' ESCAPE '\\\\'" + } + override def visitAggregateFunction( funcName: String, isDistinct: Boolean, inputs: Array[String]): String = if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { @@ -89,10 +104,15 @@ private case object MySQLDialect extends JdbcDialect with 
SQLConfHelper { override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // MariaDB connector behaviour // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. md.putLong("binarylong", 1) Option(LongType) + } else if (sqlType == Types.BIT && size > 1) { + // MySQL connector behaviour + md.putLong("binarylong", 1) + Option(LongType) } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { Option(BooleanType) } else if ("TINYTEXT".equalsIgnoreCase(typeName)) { @@ -102,6 +122,12 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { // Some MySQL JDBC drivers converts JSON type into Types.VARCHAR with a precision of -1. // Explicitly converts it into StringType here. Some(StringType) + } else if (sqlType == Types.TINYINT) { + if (md.build().getBoolean("isSigned")) { + Some(ByteType) + } else { + Some(ShortType) + } } else None } @@ -128,10 +154,6 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { schemaBuilder.result } - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html @@ -184,6 +206,7 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { // We override getJDBCType so that FloatType is mapped to FLOAT instead case FloatType => Option(JdbcType("FLOAT", java.sql.Types.FLOAT)) case StringType => Option(JdbcType("LONGTEXT", java.sql.Types.LONGVARCHAR)) + case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT)) case _ => JdbcUtils.getCommonJDBCType(dt) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 3a0333cca33fd..95774d38e50ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -118,7 +118,7 @@ private case object OracleDialect extends JdbcDialect { case DoubleType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE)) case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT)) case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT)) - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 9c1ca2cb913e6..f8f72d88589e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -129,10 +129,6 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { case _ => None } - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) /** @@ -297,8 +293,9 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { val POSTGRESQL_DATE_POSITIVE_INFINITY = 9223372036825200000L val POSTGRESQL_DATE_DATE_POSITIVE_SMALLER_INFINITY = 185543533774800000L - val minTimeStamp = 
LocalDateTime.of(1, 1, 1, 0, 0, 0).toEpochSecond(ZoneOffset.UTC) - val maxTimestamp = LocalDateTime.of(9999, 12, 31, 23, 59, 59).toEpochSecond(ZoneOffset.UTC) + val minTimeStamp = LocalDateTime.of(1, 1, 1, 0, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli + val maxTimestamp = + LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999).toInstant(ZoneOffset.UTC).toEpochMilli val time = t.getTime if (time == POSTGRESQL_DATE_POSITIVE_INFINITY || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala index 7cd7db4088ac9..ce3e7cde01b7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPage.scala @@ -174,7 +174,7 @@ private[ui] class StreamingQueryPagedTable( override def row(query: StructuredStreamingRow): Seq[Node] = { val streamingQuery = query.streamingUIData - val statisticsLink = "%s/%s/statistics?id=%s" + val statisticsLink = "%s/%s/statistics/?id=%s" .format(SparkUIUtils.prependBaseUri(request, parent.basePath), parent.prefix, streamingQuery.summary.runId) diff --git a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala index 3c96f612da6bb..fa5bea5f9bbe3 100644 --- a/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala +++ b/sql/core/src/main/scala/org/apache/spark/status/api/v1/sql/SqlResource.scala @@ -56,10 +56,9 @@ private[v1] class SqlResource extends BaseAppResource { planDescription: Boolean): ExecutionData = { withUI { ui => val sqlStore = new SQLAppStatusStore(ui.store.store) - val graph = sqlStore.planGraph(execId) sqlStore .execution(execId) - .map(prepareExecutionData(_, graph, details, planDescription)) + .map(prepareExecutionData(_, sqlStore.planGraph(execId), details, planDescription)) .getOrElse(throw new NotFoundException("unknown query execution id: " + execId)) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 4f7cf8da78722..f416d411322ee 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1783,6 +1783,23 @@ public void testEmptyBean() { Assert.assertEquals(1, df.collectAsList().size()); } + public static class ReadOnlyPropertyBean implements Serializable { + public boolean isEmpty() { + return true; + } + } + + @Test + public void testReadOnlyPropertyBean() { + ReadOnlyPropertyBean bean = new ReadOnlyPropertyBean(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataset(data, + Encoders.bean(ReadOnlyPropertyBean.class)); + Assert.assertEquals(1, df.schema().length()); + Assert.assertEquals(1, df.collectAsList().size()); + + } + public class CircularReference1Bean implements Serializable { private CircularReference2Bean child; diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out index cd101c7a524a1..3b196ea93e40c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out @@ -531,6 +531,13 @@ Project [array_insert(array(2, 3, cast(null as int), 4), -5, 
1, false) AS array_ +- OneRowRelation +-- !query +select array_insert(array(1), 2, cast(2 as tinyint)) +-- !query analysis +Project [array_insert(array(1), 2, cast(cast(2 as tinyint) as int), false) AS array_insert(array(1), 2, CAST(2 AS TINYINT))#x] ++- OneRowRelation + + -- !query set spark.sql.legacy.negativeIndexInArrayInsert=true -- !query analysis @@ -740,3 +747,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) -- !query analysis Project [array_prepend(array(cast(null as string)), cast(null as string)) AS array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING))#x] +- OneRowRelation + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query analysis +Project [array_union(array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double)), array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double))) AS array_union(array(0.0, 0.0, NaN), array(0.0, 0.0, NaN))#x] ++- OneRowRelation + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query analysis +Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/double-quoted-identifiers-enabled.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/double-quoted-identifiers-enabled.sql.out index 0a009a3a282f9..b45e461264e27 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/double-quoted-identifiers-enabled.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/double-quoted-identifiers-enabled.sql.out @@ -434,7 +434,7 @@ Project [a1#x AS a2#x] : +- OneRowRelation +- Project [a#x] +- SubqueryAlias v - +- CTERelationRef xxxx, true, [a#x] + +- CTERelationRef xxxx, true, [a#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out index 08d3be615b314..8fe6e7097e67e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/higher-order-functions.sql.out @@ -34,6 +34,25 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out index 1120c40ac15c4..2e2a07beb7176 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out @@ -1916,7 +1916,7 @@ org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "_LEGACY_ERROR_TEMP_0063", "messageParameters" : { - "msg" : "Interval string does not match 
year-month format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval year to month: -\t2-2\t" + "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out index 53c7327c58717..001dd4d644873 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/literals.sql.out @@ -692,3 +692,10 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query analysis +Project [0 AS 0#x, 0.0 AS 0.0#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out index bbc07c22805a6..15fe614ff0d22 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out @@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x] +- OneRowRelation +-- !query +SELECT try_add(2147483647, decimal(1)) +-- !query analysis +Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x] ++- OneRowRelation + + +-- !query +SELECT try_add(2147483647, "1") +-- !query analysis +Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#xL] ++- OneRowRelation + + -- !query SELECT try_add(-2147483648, -1) -- !query analysis @@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + +-- !query +SELECT try_divide(1, "0") +-- !query analysis +Project [try_divide(1, 0) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis @@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x] +- OneRowRelation +-- !query +SELECT try_subtract(2147483647, decimal(-1)) +-- !query analysis +Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x] ++- OneRowRelation + + +-- !query +SELECT try_subtract(2147483647, "-1") +-- !query analysis +Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#xL] ++- OneRowRelation + + -- !query SELECT try_subtract(-2147483648, 1) -- !query analysis @@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x] +- OneRowRelation +-- !query +SELECT try_multiply(2147483647, decimal(-2)) +-- !query analysis +Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x] ++- OneRowRelation + + +-- !query +SELECT try_multiply(2147483647, "-2") +-- !query analysis +Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#xL] ++- OneRowRelation + + -- !query SELECT try_multiply(-2147483648, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out index 8279fb3362e54..ca6c89bfadc3d 100644 
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out @@ -531,6 +531,13 @@ Project [array_insert(array(2, 3, cast(null as int), 4), -5, 1, false) AS array_ +- OneRowRelation +-- !query +select array_insert(array(1), 2, cast(2 as tinyint)) +-- !query analysis +Project [array_insert(array(1), 2, cast(cast(2 as tinyint) as int), false) AS array_insert(array(1), 2, CAST(2 AS TINYINT))#x] ++- OneRowRelation + + -- !query set spark.sql.legacy.negativeIndexInArrayInsert=true -- !query analysis @@ -740,3 +747,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) -- !query analysis Project [array_prepend(array(cast(null as string)), cast(null as string)) AS array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING))#x] +- OneRowRelation + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query analysis +Project [array_union(array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double)), array(cast(0.0 as double), cast(0.0 as double), cast(NaN as double))) AS array_union(array(0.0, 0.0, NaN), array(0.0, 0.0, NaN))#x] ++- OneRowRelation + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query analysis +Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out index 6e72fd28686a0..544d736b56b64 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out @@ -255,6 +255,18 @@ desc formatted char_part DescribeTableCommand `spark_catalog`.`default`.`char_part`, true, [col_name#x, data_type#x, comment#x] +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, c1, StructField(c1,CharType(5),true) + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, v1, StructField(v1,VarcharType(6),true) + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out index d96965edde136..de0e6dfae2ce3 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out @@ -15,10 +15,10 @@ WithCTE : +- SubqueryAlias t : +- Project [1#x] : +- SubqueryAlias t2 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -37,7 +37,7 @@ Aggregate [max(c#x) AS max(c)#x] : +- OneRowRelation +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -54,7 +54,7 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] : : +- OneRowRelation : +- 
Project [1#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- OneRowRelation @@ -136,11 +136,11 @@ WithCTE : : : +- OneRowRelation : : +- Project [c#x] : : +- SubqueryAlias t -: : +- CTERelationRef xxxx, true, [c#x] +: : +- CTERelationRef xxxx, true, [c#x], false : +- OneRowRelation +- Project [scalarsubquery()#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [scalarsubquery()#x] + +- CTERelationRef xxxx, true, [scalarsubquery()#x], false -- !query @@ -189,7 +189,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -218,7 +218,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -253,7 +253,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -352,14 +352,14 @@ WithCTE : +- SubqueryAlias t : +- Project [1#x] : +- SubqueryAlias t2 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias t2 : +- Project [2 AS 2#x] : +- OneRowRelation +- Project [1#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -420,15 +420,15 @@ WithCTE : +- SubqueryAlias t3 : +- Project [1#x] : +- SubqueryAlias t1 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias t2 : +- Project [1#x] : +- SubqueryAlias t3 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -451,12 +451,12 @@ WithCTE : +- SubqueryAlias cte_inner : +- Project [1#x] : +- SubqueryAlias cte_outer -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [1#x] +- SubqueryAlias cte_inner - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -484,19 +484,19 @@ WithCTE : +- SubqueryAlias cte_inner_inner : +- Project [1#x] : +- SubqueryAlias cte_outer -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias cte_inner : +- Project [1#x] : +- SubqueryAlias __auto_generated_subquery_name : +- Project [1#x] : +- SubqueryAlias cte_inner_inner -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [1#x] +- SubqueryAlias cte_inner - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out index bd9b443d01d0a..f1a302b06f2a8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out @@ -15,10 +15,10 @@ WithCTE : +- SubqueryAlias t : +- Project [1#x] : +- SubqueryAlias t2 -: 
+- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -37,7 +37,7 @@ Aggregate [max(c#x) AS max(c)#x] : +- OneRowRelation +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -54,7 +54,7 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] : : +- OneRowRelation : +- Project [1#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- OneRowRelation @@ -106,10 +106,10 @@ WithCTE : +- SubqueryAlias t2 : +- Project [2#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [2#x] +: +- CTERelationRef xxxx, true, [2#x], false +- Project [2#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [2#x] + +- CTERelationRef xxxx, true, [2#x], false -- !query @@ -144,11 +144,11 @@ WithCTE : : : +- OneRowRelation : : +- Project [c#x] : : +- SubqueryAlias t -: : +- CTERelationRef xxxx, true, [c#x] +: : +- CTERelationRef xxxx, true, [c#x], false : +- OneRowRelation +- Project [scalarsubquery()#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [scalarsubquery()#x] + +- CTERelationRef xxxx, true, [scalarsubquery()#x], false -- !query @@ -181,15 +181,15 @@ WithCTE : +- SubqueryAlias t2 : +- Project [3#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [3#x] +: +- CTERelationRef xxxx, true, [3#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias t2 : +- Project [3#x] : +- SubqueryAlias t2 -: +- CTERelationRef xxxx, true, [3#x] +: +- CTERelationRef xxxx, true, [3#x], false +- Project [3#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [3#x] + +- CTERelationRef xxxx, true, [3#x], false -- !query @@ -214,7 +214,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -243,7 +243,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -278,7 +278,7 @@ WithCTE +- SubqueryAlias __auto_generated_subquery_name +- Project [c#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -301,7 +301,7 @@ WithCTE : : +- OneRowRelation : +- Project [2#x] : +- SubqueryAlias t - : +- CTERelationRef xxxx, true, [2#x] + : +- CTERelationRef xxxx, true, [2#x], false +- OneRowRelation @@ -328,7 +328,7 @@ WithCTE : : : +- OneRowRelation : : +- Project [2#x] : : +- SubqueryAlias t - : : +- CTERelationRef xxxx, true, [2#x] + : : +- CTERelationRef xxxx, true, [2#x], false : +- OneRowRelation +- OneRowRelation @@ -362,7 +362,7 @@ WithCTE : : : +- OneRowRelation : : +- Project [3#x] : : +- SubqueryAlias t - : : +- CTERelationRef xxxx, true, [3#x] + : : +- CTERelationRef xxxx, true, [3#x], false : +- OneRowRelation +- OneRowRelation @@ -391,9 +391,9 @@ WithCTE : : +- OneRowRelation : +- Project [c#x] : +- SubqueryAlias t - : +- CTERelationRef xxxx, true, [c#x] + : +- CTERelationRef xxxx, true, [c#x], false +- SubqueryAlias t - +- CTERelationRef xxxx, true, [c#x] + +- CTERelationRef xxxx, true, [c#x], false -- !query @@ -414,14 +414,14 @@ WithCTE : +- SubqueryAlias t : +- Project [1#x] : +- SubqueryAlias t2 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- 
CTERelationDef xxxx, false : +- SubqueryAlias t2 : +- Project [2 AS 2#x] : +- OneRowRelation +- Project [1#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -446,10 +446,10 @@ WithCTE : +- SubqueryAlias t : +- Project [2#x] : +- SubqueryAlias aBC -: +- CTERelationRef xxxx, true, [2#x] +: +- CTERelationRef xxxx, true, [2#x], false +- Project [2#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [2#x] + +- CTERelationRef xxxx, true, [2#x], false -- !query @@ -472,7 +472,7 @@ WithCTE : : +- OneRowRelation : +- Project [2#x] : +- SubqueryAlias aBC - : +- CTERelationRef xxxx, true, [2#x] + : +- CTERelationRef xxxx, true, [2#x], false +- OneRowRelation @@ -496,15 +496,15 @@ WithCTE : +- SubqueryAlias t3 : +- Project [1#x] : +- SubqueryAlias t1 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias t2 : +- Project [1#x] : +- SubqueryAlias t3 -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -527,12 +527,12 @@ WithCTE : +- SubqueryAlias cte_inner : +- Project [1#x] : +- SubqueryAlias cte_outer -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [1#x] +- SubqueryAlias cte_inner - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -560,19 +560,19 @@ WithCTE : +- SubqueryAlias cte_inner_inner : +- Project [1#x] : +- SubqueryAlias cte_outer -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias cte_inner : +- Project [1#x] : +- SubqueryAlias __auto_generated_subquery_name : +- Project [1#x] : +- SubqueryAlias cte_inner_inner -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- Project [1#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [1#x] +- SubqueryAlias cte_inner - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte.sql.out index b9a0f776528d8..e817aaf9e59ff 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte.sql.out @@ -73,7 +73,7 @@ WithCTE : +- LocalRelation [id#x] +- Project [1#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [1#x] + +- CTERelationRef xxxx, true, [1#x], false -- !query @@ -113,13 +113,13 @@ WithCTE : +- SubqueryAlias t2 : +- Project [2 AS 2#x] : +- SubqueryAlias t1 -: +- CTERelationRef xxxx, true, [id#x] +: +- CTERelationRef xxxx, true, [id#x], false +- Project [id#x, 2#x] +- Join Cross :- SubqueryAlias t1 - : +- CTERelationRef xxxx, true, [id#x] + : +- CTERelationRef xxxx, true, [id#x], false +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [2#x] + +- CTERelationRef xxxx, true, [2#x], false -- !query @@ -157,10 +157,10 @@ WithCTE +- Join Cross :- SubqueryAlias t1 : +- SubqueryAlias CTE1 - : +- CTERelationRef xxxx, true, [id#x] + : +- CTERelationRef xxxx, true, [id#x], false +- SubqueryAlias t2 +- SubqueryAlias CTE1 - +- CTERelationRef xxxx, true, [id#x] + +- CTERelationRef xxxx, true, [id#x], false -- !query @@ 
-176,7 +176,7 @@ WithCTE +- Project [x#x] +- Filter (x#x = 1) +- SubqueryAlias t - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query @@ -192,7 +192,7 @@ WithCTE +- Project [x#x, y#x] +- Filter ((x#x = 1) AND (y#x = 2)) +- SubqueryAlias t - +- CTERelationRef xxxx, true, [x#x, y#x] + +- CTERelationRef xxxx, true, [x#x, y#x], false -- !query @@ -207,7 +207,7 @@ WithCTE : +- OneRowRelation +- Project [x#x, x#x] +- SubqueryAlias t - +- CTERelationRef xxxx, true, [x#x, x#x] + +- CTERelationRef xxxx, true, [x#x, x#x], false -- !query @@ -309,46 +309,46 @@ WithCTE : +- Project [c8#x AS c7#x] : +- Project [c8#x] : +- SubqueryAlias w8 -: +- CTERelationRef xxxx, true, [c8#x] +: +- CTERelationRef xxxx, true, [c8#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w6 : +- Project [c7#x AS c6#x] : +- Project [c7#x] : +- SubqueryAlias w7 -: +- CTERelationRef xxxx, true, [c7#x] +: +- CTERelationRef xxxx, true, [c7#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w5 : +- Project [c6#x AS c5#x] : +- Project [c6#x] : +- SubqueryAlias w6 -: +- CTERelationRef xxxx, true, [c6#x] +: +- CTERelationRef xxxx, true, [c6#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w4 : +- Project [c5#x AS c4#x] : +- Project [c5#x] : +- SubqueryAlias w5 -: +- CTERelationRef xxxx, true, [c5#x] +: +- CTERelationRef xxxx, true, [c5#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w3 : +- Project [c4#x AS c3#x] : +- Project [c4#x] : +- SubqueryAlias w4 -: +- CTERelationRef xxxx, true, [c4#x] +: +- CTERelationRef xxxx, true, [c4#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w2 : +- Project [c3#x AS c2#x] : +- Project [c3#x] : +- SubqueryAlias w3 -: +- CTERelationRef xxxx, true, [c3#x] +: +- CTERelationRef xxxx, true, [c3#x], false :- CTERelationDef xxxx, false : +- SubqueryAlias w1 : +- Project [c2#x AS c1#x] : +- Project [c2#x] : +- SubqueryAlias w2 -: +- CTERelationRef xxxx, true, [c2#x] +: +- CTERelationRef xxxx, true, [c2#x], false +- Project [c1#x] +- SubqueryAlias w1 - +- CTERelationRef xxxx, true, [c1#x] + +- CTERelationRef xxxx, true, [c1#x], false -- !query @@ -384,7 +384,7 @@ WithCTE +- Project [42#x, 10#x] +- Join Inner :- SubqueryAlias same_name - : +- CTERelationRef xxxx, true, [42#x] + : +- CTERelationRef xxxx, true, [42#x], false +- SubqueryAlias same_name +- Project [10 AS 10#x] +- OneRowRelation @@ -423,7 +423,7 @@ WithCTE : +- OneRowRelation +- Project [x#x, typeof(x#x) AS typeof(x)#x] +- SubqueryAlias q - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query @@ -483,7 +483,7 @@ Project [y#x] : +- OneRowRelation +- Project [(x#x + 1) AS y#x] +- SubqueryAlias q - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query @@ -497,7 +497,7 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] : : +- OneRowRelation : +- Project [x#x] : +- SubqueryAlias q -: +- CTERelationRef xxxx, true, [x#x] +: +- CTERelationRef xxxx, true, [x#x], false +- OneRowRelation @@ -512,7 +512,7 @@ Project [1 IN (list#x []) AS (1 IN (listquery()))#x] : : +- OneRowRelation : +- Project [1#x] : +- SubqueryAlias q -: +- CTERelationRef xxxx, true, [1#x] +: +- CTERelationRef xxxx, true, [1#x], false +- OneRowRelation @@ -560,14 +560,14 @@ WithCTE :- Join Inner : :- SubqueryAlias x : : +- SubqueryAlias T1 - : : +- CTERelationRef xxxx, true, [a#x] + : : +- CTERelationRef xxxx, true, [a#x], false : +- SubqueryAlias y : +- Project [b#x] : +- SubqueryAlias T1 - : +- CTERelationRef xxxx, true, 
[b#x] + : +- CTERelationRef xxxx, true, [b#x], false +- SubqueryAlias z +- SubqueryAlias T1 - +- CTERelationRef xxxx, true, [a#x] + +- CTERelationRef xxxx, true, [a#x], false -- !query @@ -595,9 +595,9 @@ WithCTE +- Project [c#x, a#x] +- Join Inner :- SubqueryAlias ttTT - : +- CTERelationRef xxxx, true, [c#x] + : +- CTERelationRef xxxx, true, [c#x], false +- SubqueryAlias tttT_2 - +- CTERelationRef xxxx, true, [a#x] + +- CTERelationRef xxxx, true, [a#x], false -- !query @@ -613,7 +613,7 @@ Project [scalar-subquery#x [x#x] AS scalarsubquery(x)#x] : : +- OneRowRelation : +- Project [x#x] : +- SubqueryAlias q -: +- CTERelationRef xxxx, true, [x#x] +: +- CTERelationRef xxxx, true, [x#x], false +- SubqueryAlias T +- Project [1 AS x#x, 2 AS y#x] +- OneRowRelation @@ -632,7 +632,7 @@ Project [scalar-subquery#x [x#x && y#x] AS scalarsubquery(x, y)#x] : : +- OneRowRelation : +- Project [((outer(x#x) + outer(y#x)) + z#x) AS ((outer(T.x) + outer(T.y)) + z)#x] : +- SubqueryAlias q -: +- CTERelationRef xxxx, true, [z#x] +: +- CTERelationRef xxxx, true, [z#x], false +- SubqueryAlias T +- Project [1 AS x#x, 2 AS y#x] +- OneRowRelation @@ -652,12 +652,12 @@ WithCTE : +- SubqueryAlias q2 : +- Project [x#x] : +- SubqueryAlias q1 -: +- CTERelationRef xxxx, true, [x#x] +: +- CTERelationRef xxxx, true, [x#x], false +- Project [x#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [x#x] +- SubqueryAlias q2 - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query @@ -674,12 +674,12 @@ WithCTE : +- SubqueryAlias q1 : +- Project [(x#x + 1) AS (x + 1)#x] : +- SubqueryAlias q1 -: +- CTERelationRef xxxx, true, [x#x] +: +- CTERelationRef xxxx, true, [x#x], false +- Project [(x + 1)#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [(x + 1)#x] +- SubqueryAlias q1 - +- CTERelationRef xxxx, true, [(x + 1)#x] + +- CTERelationRef xxxx, true, [(x + 1)#x], false -- !query @@ -720,9 +720,9 @@ WithCTE : +- Aggregate [max(j#x) AS max(j)#x] : +- SubqueryAlias cte2 : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [j#x] + : +- CTERelationRef xxxx, true, [j#x], false +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [j#x] + +- CTERelationRef xxxx, true, [j#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out index f37e31bdb389c..522cfb0cbbd28 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], For -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out index f37e31bdb389c..522cfb0cbbd28 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], 
For -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out index c8c34a856d492..1bcde5bd367f7 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by-ordinal.sql.out @@ -61,7 +61,7 @@ Aggregate [a#x, a#x], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] -- !query select a, 1, sum(b) from data group by 1, 2 -- !query analysis -Aggregate [a#x, 1], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] +Aggregate [a#x, 2], [a#x, 1 AS 1#x, sum(b#x) AS sum(b)#xL] +- SubqueryAlias data +- View (`data`, [a#x,b#x]) +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 202ceee18046a..93c463575dc1a 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1196,3 +1196,22 @@ Aggregate [c#x], [(c#x * 2) AS d#x] +- Project [if ((a#x < 0)) 0 else a#x AS b#x] +- SubqueryAlias t1 +- LocalRelation [a#x] + + +-- !query +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +-- !query analysis +Sort [col1#x ASC NULLS FIRST], true ++- Aggregate [col1#x], [col1#x, count(1) AS cnt#xL] + +- LocalRelation [col1#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out index f656716a843e0..d851019860789 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/higher-order-functions.sql.out @@ -34,6 +34,25 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index 00e2d8ff8ae75..ecab824f0995d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -187,10 +187,11 @@ Project [coalesce(cast(null as int), 1) AS coalesce(NULL, 1)#x] -- !query -SELECT IDENTIFIER('abs')(-1) 
+SELECT IDENTIFIER('abs')(c1) FROM VALUES(-1) AS T(c1) -- !query analysis -Project [abs(-1) AS abs(-1)#x] -+- OneRowRelation +Project [abs(c1#x) AS abs(c1)#x] ++- SubqueryAlias T + +- LocalRelation [c1#x] -- !query @@ -665,7 +666,28 @@ org.apache.spark.sql.AnalysisException -- !query -CREATE TABLE IDENTIFIER(1)(c1 INT) +SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_ROUTINE", + "sqlState" : "42883", + "messageParameters" : { + "routineName" : "`IDENTIFIER`", + "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "`IDENTIFIER`('abs')" + } ] +} + + +-- !query +CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv -- !query analysis org.apache.spark.sql.AnalysisException { @@ -687,7 +709,7 @@ org.apache.spark.sql.AnalysisException -- !query -CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) +CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv -- !query analysis org.apache.spark.sql.AnalysisException { @@ -859,6 +881,65 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CreateViewCommand `v1`, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, false, LocalTempView, true + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CacheTableAsSelect t1, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, true + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +create table identifier('t2') using csv as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query analysis +CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t2`, ErrorIfExists, [my_col] + +- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1 +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [my_col] ++- Aggregate [my_col#x], [my_col#x] + +- SubqueryAlias __auto_generated_subquery_name + +- SubqueryAlias as + +- LocalRelation [my_col#x] + + +-- !query +drop view v1 +-- !query analysis +DropTempViewCommand v1 + + +-- !query +drop table t1 +-- !query analysis +DropTempViewCommand t1 + + +-- !query +drop table t2 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2 + + -- !query SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out index 2a17f092a06b7..adce16bf23578 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out @@ -73,9 +73,7 @@ Project [a#x, b#x] -- !query select a from values ("one", current_timestamp) as data(a, b) -- !query analysis -Project [a#x] -+- SubqueryAlias data - +- LocalRelation [a#x, b#x] +[Analyzer test output redacted due to nondeterminism] -- !query @@ -241,3 +239,15 @@ select * from values (10 + try_divide(5, 0)) -- !query analysis Project [col1#x] +- LocalRelation [col1#x] + + +-- !query +select count(distinct ct) from values now(), now(), now() as data(ct) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out index 337edd5980c39..6242dc142eabb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out @@ -1916,7 +1916,7 @@ org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "_LEGACY_ERROR_TEMP_0063", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval year to month: -\t2-2\t" + "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out index 4c032b7cbf9a2..2c7b31f62c6f4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out @@ -1310,10 +1310,10 @@ WithCTE : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] : : +- LocalRelation [col1#x, col2#x] : +- SubqueryAlias cte1 -: +- CTERelationRef xxxx, true, [c1#x] +: +- CTERelationRef xxxx, true, [c1#x], false +- Project [c1#x, c2#x] +- SubqueryAlias cte2 - +- CTERelationRef xxxx, true, [c1#x, c2#x] + +- CTERelationRef xxxx, true, [c1#x, c2#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out index 7c91139921b58..a791a4f35e98a 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out @@ -1,11 +1,11 @@ -- Automatically generated by SQLQueryTestSuite -- !query CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x) -- !query analysis CreateViewCommand `testRegression`, SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 
11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x), false, true, LocalTempView, true +- Project [k#x, y#x, x#x] +- SubqueryAlias testRegression diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out index 53c7327c58717..001dd4d644873 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/literals.sql.out @@ -692,3 +692,10 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query analysis +Project [0 AS 0#x, 0.0 AS 0.0#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/non-excludable-rule.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/non-excludable-rule.sql.out index 305a59f01e443..b80bed6f7c2aa 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/non-excludable-rule.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/non-excludable-rule.sql.out @@ -47,7 +47,7 @@ WithCTE +- Filter (id#xL > scalar-subquery#x []) : +- Aggregate [max(id#xL) AS max(id)#xL] : +- SubqueryAlias tmp - : +- CTERelationRef xxxx, true, [id#xL] + : +- CTERelationRef xxxx, true, [id#xL], false +- Range (0, 3, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out index b199cb55f2a44..7f477c80d46ca 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out @@ -1661,7 +1661,7 @@ select * from tt7a left join tt8a using (x), tt8a tt8ax, false, false, Persisted :- Project [a#x, b#x, c#x, d#x, e#x] : +- SubqueryAlias v : +- Project [col1#x AS a#x, col2#x AS b#x, col3#x AS c#x, col4#x AS d#x, col5#x AS e#x] - : +- LocalRelation [col1#x, col2#x, col3#x, col4#x, col5#x] + : +- ResolvedInlineTable [[now(), 2, 3, now(), 5]], [col1#x, col2#x, col3#x, col4#x, col5#x] +- Project [cast(x#x as timestamp) AS x#x, y#x, z#x, x#x, z#x] +- Project [x#x, y#x, z#x, x#x, z#x] +- Join Inner diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out index 6b6a37b4e7fb4..6698d1fb083f0 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part3.sql.out @@ -98,7 +98,7 @@ WithCTE +- Window [sum(x#xL) windowspecdefinition(x#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, 1)) AS sum(x) OVER (ORDER BY x ASC NULLS FIRST ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)#xL], [x#xL ASC NULLS FIRST] +- Project [x#xL] +- SubqueryAlias cte - +- CTERelationRef xxxx, true, [x#xL] + +- CTERelationRef xxxx, true, [x#xL], false -- !query @@ -120,7 +120,7 @@ WithCTE +- Window [sum(x#xL) windowspecdefinition(x#xL ASC NULLS FIRST, specifiedwindowframe(RangeFrame, cast(-1 as bigint), cast(1 as bigint))) AS sum(x) OVER (ORDER BY x ASC NULLS FIRST RANGE BETWEEN (- 1) FOLLOWING AND 1 FOLLOWING)#xL], [x#xL ASC NULLS FIRST] +- Project [x#xL] +- SubqueryAlias cte - +- CTERelationRef xxxx, true, [x#xL] + +- CTERelationRef xxxx, true, [x#xL], false -- 
!query @@ -153,7 +153,7 @@ WithCTE +- Window [sum(x#xL) windowspecdefinition(x#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, 1)) AS sum(x) OVER (ORDER BY x ASC NULLS FIRST ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)#xL], [x#xL ASC NULLS FIRST] +- Project [x#xL] +- SubqueryAlias cte - +- CTERelationRef xxxx, true, [x#xL] + +- CTERelationRef xxxx, true, [x#xL], false -- !query @@ -186,7 +186,7 @@ WithCTE +- Window [sum(x#xL) windowspecdefinition(x#xL ASC NULLS FIRST, specifiedwindowframe(RangeFrame, cast(-1 as bigint), cast(1 as bigint))) AS sum(x) OVER (ORDER BY x ASC NULLS FIRST RANGE BETWEEN (- 1) FOLLOWING AND 1 FOLLOWING)#xL], [x#xL ASC NULLS FIRST] +- Project [x#xL] +- SubqueryAlias cte - +- CTERelationRef xxxx, true, [x#xL] + +- CTERelationRef xxxx, true, [x#xL], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out index c978c583152c5..b3ce967f2a6b5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/with.sql.out @@ -12,10 +12,10 @@ WithCTE +- Project [x#x, y#x, x#x, y#x] +- Join Inner :- SubqueryAlias q1 - : +- CTERelationRef xxxx, true, [x#x, y#x] + : +- CTERelationRef xxxx, true, [x#x, y#x], false +- SubqueryAlias q2 +- SubqueryAlias q1 - +- CTERelationRef xxxx, true, [x#x, y#x] + +- CTERelationRef xxxx, true, [x#x, y#x], false -- !query @@ -194,7 +194,7 @@ WithCTE +- SubqueryAlias q +- Project [foo#x] +- SubqueryAlias cte - +- CTERelationRef xxxx, true, [foo#x] + +- CTERelationRef xxxx, true, [foo#x], false -- !query @@ -222,13 +222,13 @@ WithCTE : +- Union false, false : :- Project [2#x] : : +- SubqueryAlias innermost -: : +- CTERelationRef xxxx, true, [2#x] +: : +- CTERelationRef xxxx, true, [2#x], false : +- Project [3 AS 3#x] : +- OneRowRelation +- Sort [x#x ASC NULLS FIRST], true +- Project [x#x] +- SubqueryAlias outermost - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query @@ -418,7 +418,7 @@ WithCTE : +- OneRowRelation +- Project [x#x] +- SubqueryAlias ordinality - +- CTERelationRef xxxx, true, [x#x] + +- CTERelationRef xxxx, true, [x#x], false -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-cte.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-cte.sql.out index 2cd6ba5356371..cab83b2649974 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/exists-subquery/exists-cte.sql.out @@ -133,7 +133,7 @@ WithCTE : +- Filter (outer(emp_name#x) = emp_name#x) : +- SubqueryAlias b : +- SubqueryAlias bonus_cte - : +- CTERelationRef xxxx, true, [emp_name#x, bonus_amt#x] + : +- CTERelationRef xxxx, true, [emp_name#x, bonus_amt#x], false +- SubqueryAlias a +- SubqueryAlias bonus +- View (`BONUS`, [emp_name#x,bonus_amt#x]) @@ -189,10 +189,10 @@ WithCTE : +- Join Inner, (dept_id#x = dept_id#x) : :- SubqueryAlias a : : +- SubqueryAlias emp_cte - : : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + : : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x], false : +- SubqueryAlias b : +- SubqueryAlias dept_cte - : +- CTERelationRef xxxx, true, [dept_id#x, dept_name#x, state#x] + : +- CTERelationRef xxxx, 
true, [dept_id#x, dept_name#x, state#x], false +- SubqueryAlias bonus +- View (`BONUS`, [emp_name#x,bonus_amt#x]) +- Project [cast(emp_name#x as string) AS emp_name#x, cast(bonus_amt#x as double) AS bonus_amt#x] @@ -253,10 +253,10 @@ WithCTE : +- Join LeftOuter, (dept_id#x = dept_id#x) : :- SubqueryAlias a : : +- SubqueryAlias emp_cte - : : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + : : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x], false : +- SubqueryAlias b : +- SubqueryAlias dept_cte - : +- CTERelationRef xxxx, true, [dept_id#x, dept_name#x, state#x] + : +- CTERelationRef xxxx, true, [dept_id#x, dept_name#x, state#x], false +- Join Inner :- Join Inner : :- SubqueryAlias b @@ -268,7 +268,7 @@ WithCTE : : +- LocalRelation [emp_name#x, bonus_amt#x] : +- SubqueryAlias e : +- SubqueryAlias emp_cte - : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] + : +- CTERelationRef xxxx, true, [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x], false +- SubqueryAlias d +- SubqueryAlias dept +- View (`DEPT`, [dept_id#x,dept_name#x,state#x]) @@ -322,7 +322,7 @@ WithCTE : +- Filter (count(1)#xL > cast(1 as bigint)) : +- Aggregate [dept_id#x], [dept_id#x, max(salary#x) AS max(salary)#x, count(1) AS count(1)#xL] : +- SubqueryAlias empdept - : +- CTERelationRef xxxx, true, [id#x, salary#x, emp_name#x, dept_id#x] + : +- CTERelationRef xxxx, true, [id#x, salary#x, emp_name#x, dept_id#x], false +- SubqueryAlias bonus +- View (`BONUS`, [emp_name#x,bonus_amt#x]) +- Project [cast(emp_name#x as string) AS emp_name#x, cast(bonus_amt#x as double) AS bonus_amt#x] @@ -375,7 +375,7 @@ WithCTE : +- Filter (count(1)#xL < cast(1 as bigint)) : +- Aggregate [dept_id#x], [dept_id#x, max(salary#x) AS max(salary)#x, count(1) AS count(1)#xL] : +- SubqueryAlias empdept - : +- CTERelationRef xxxx, true, [id#x, salary#x, emp_name#x, dept_id#x] + : +- CTERelationRef xxxx, true, [id#x, salary#x, emp_name#x, dept_id#x], false +- SubqueryAlias bonus +- View (`BONUS`, [emp_name#x,bonus_amt#x]) +- Project [cast(emp_name#x as string) AS emp_name#x, cast(bonus_amt#x as double) AS bonus_amt#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-multiple-columns.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-multiple-columns.sql.out index ab16f4b9d687c..1717e553f5c3c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-multiple-columns.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-multiple-columns.sql.out @@ -330,7 +330,7 @@ WithCTE +- Project [t1a#x, t1b#x, t1a#x, t1b#x] +- Join Inner, (t1b#x = t1b#x) :- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias cte2 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-with-cte.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-with-cte.sql.out index 9d82c707177b7..6d0a944bfcfe2 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-with-cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-with-cte.sql.out @@ -138,7 +138,7 @@ WithCTE : +- Project [t1b#x] : +- 
Filter (cast(t1b#x as int) > 0) : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias t1 +- View (`t1`, [t1a#x,t1b#x,t1c#x,t1d#xL,t1e#x,t1f#x,t1g#x,t1h#x,t1i#x]) +- Project [cast(t1a#x as string) AS t1a#x, cast(t1b#x as smallint) AS t1b#x, cast(t1c#x as int) AS t1c#x, cast(t1d#xL as bigint) AS t1d#xL, cast(t1e#x as float) AS t1e#x, cast(t1f#x as double) AS t1f#x, cast(t1g#x as double) AS t1g#x, cast(t1h#x as timestamp) AS t1h#x, cast(t1i#x as date) AS t1i#x] @@ -197,21 +197,21 @@ WithCTE : : : :- Project [t1b#x] : : : : +- Filter (cast(t1b#x as int) > 0) : : : : +- SubqueryAlias cte1 - : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : : : +- Project [t1b#x] : : : +- Filter (cast(t1b#x as int) > 5) : : : +- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : : +- Intersect false : : :- Project [t1b#x] : : : +- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : : +- Project [t1b#x] : : +- SubqueryAlias cte1 - : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : +- Project [t1b#x] : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias t1 +- View (`t1`, [t1a#x,t1b#x,t1c#x,t1d#xL,t1e#x,t1f#x,t1g#x,t1h#x,t1i#x]) +- Project [cast(t1a#x as string) AS t1a#x, cast(t1b#x as smallint) AS t1b#x, cast(t1c#x as int) AS t1c#x, cast(t1d#xL as bigint) AS t1d#xL, cast(t1e#x as float) AS t1e#x, cast(t1f#x as double) AS t1f#x, cast(t1g#x as double) AS t1g#x, cast(t1h#x as timestamp) AS t1h#x, cast(t1i#x as date) AS t1i#x] @@ -268,22 +268,22 @@ WithCTE : : : :- Join FullOuter, (t1c#x = t1c#x) : : : : :- Join Inner, (t1b#x > t1b#x) : : : : : :- SubqueryAlias cte1 - : : : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : : : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false : : : : : +- SubqueryAlias cte2 : : : : : +- SubqueryAlias cte1 - : : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false : : : : +- SubqueryAlias cte3 : : : : +- SubqueryAlias cte1 - : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false : : : +- SubqueryAlias cte4 : : : +- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false : : +- SubqueryAlias cte5 : : +- SubqueryAlias cte1 - : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false : +- SubqueryAlias cte6 : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x], false +- SubqueryAlias t1 +- View (`t1`, [t1a#x,t1b#x,t1c#x,t1d#xL,t1e#x,t1f#x,t1g#x,t1h#x,t1i#x]) +- Project [cast(t1a#x as string) AS t1a#x, cast(t1b#x as smallint) AS t1b#x, cast(t1c#x as int) AS t1c#x, cast(t1d#xL as bigint) AS t1d#xL, cast(t1e#x as float) AS t1e#x, cast(t1f#x as 
double) AS t1f#x, cast(t1g#x as double) AS t1g#x, cast(t1h#x as timestamp) AS t1h#x, cast(t1i#x as date) AS t1i#x] @@ -354,16 +354,16 @@ WithCTE :- Join FullOuter, (t1a#x = t1a#x) : :- Join Inner, ((cast(t1b#x as int) > 5) AND (t1a#x = t1a#x)) : : :- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : : +- SubqueryAlias cte2 : : +- SubqueryAlias cte1 - : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false : +- SubqueryAlias cte3 : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias cte4 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false -- !query @@ -424,10 +424,10 @@ WithCTE +- Project [t1a#x, t1b#x] +- Join Inner, (t1h#x >= t1h#x) :- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1h#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1h#x], false +- SubqueryAlias cte2 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1h#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1h#x], false -- !query @@ -485,16 +485,16 @@ WithCTE :- Join RightOuter, (t1b#x = t1b#x) : :- Join Inner, (t1a#x = t1a#x) : : :- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x], false : : +- SubqueryAlias cte2 : : +- SubqueryAlias cte1 - : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x] + : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x], false : +- SubqueryAlias cte3 : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x], false +- SubqueryAlias cte4 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x], false -- !query @@ -538,10 +538,10 @@ WithCTE +- Project [t1a#x, t1b#x] +- Join RightOuter, (t1a#x = t1a#x) :- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias cte2 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false -- !query @@ -599,15 +599,15 @@ WithCTE : : +- SubqueryAlias t1 : : +- LocalRelation [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x, t1f#x, t1g#x, t1h#x, t1i#x] : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias s +- Project [t1b#x] +- Join LeftOuter, (t1b#x = t1b#x) :- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias cte2 +- SubqueryAlias cte1 - +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false -- !query @@ -642,7 +642,7 @@ WithCTE : +- Project [t1b#x] : +- Filter (cast(t1b#x as int) < 0) : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x], false +- SubqueryAlias t1 +- View (`t1`, [t1a#x,t1b#x,t1c#x,t1d#xL,t1e#x,t1f#x,t1g#x,t1h#x,t1i#x]) +- Project [cast(t1a#x as string) AS t1a#x, cast(t1b#x as smallint) AS t1b#x, cast(t1c#x as int) AS t1c#x, cast(t1d#xL as bigint) AS t1d#xL, cast(t1e#x as float) AS t1e#x, cast(t1f#x as double) AS t1f#x, cast(t1g#x as double) AS t1g#x, 
cast(t1h#x as timestamp) AS t1h#x, cast(t1i#x as date) AS t1i#x] @@ -722,16 +722,16 @@ WithCTE : :- Join RightOuter, (t1b#x = t1b#x) : : :- Join Inner, (t1a#x = t1a#x) : : : :- SubqueryAlias cte1 - : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x] + : : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x], false : : : +- SubqueryAlias cte2 : : : +- SubqueryAlias cte1 - : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x] + : : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x], false : : +- SubqueryAlias cte3 : : +- SubqueryAlias cte1 - : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x] + : : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x], false : +- SubqueryAlias cte4 : +- SubqueryAlias cte1 - : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x] + : +- CTERelationRef xxxx, true, [t1a#x, t1b#x, t1c#x, t1d#xL, t1h#x], false +- SubqueryAlias t1 +- View (`t1`, [t1a#x,t1b#x,t1c#x,t1d#xL,t1e#x,t1f#x,t1g#x,t1h#x,t1i#x]) +- Project [cast(t1a#x as string) AS t1a#x, cast(t1b#x as smallint) AS t1b#x, cast(t1c#x as int) AS t1c#x, cast(t1d#xL as bigint) AS t1d#xL, cast(t1e#x as float) AS t1e#x, cast(t1f#x as double) AS t1f#x, cast(t1g#x as double) AS t1g#x, cast(t1h#x as timestamp) AS t1h#x, cast(t1i#x as date) AS t1i#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out index cb41f7cdc4557..c7271d8b85628 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -623,7 +623,7 @@ Project [c1#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#x] : : +- OneRowRelation : +- Project [(a#x + outer(c1#x)) AS (a + outer(t1.c1))#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [a#x] +: +- CTERelationRef xxxx, true, [a#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -647,7 +647,7 @@ Project [c1#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL] : : +- LocalRelation [c1#x, c2#x] : +- Aggregate [sum(c2#x) AS sum(c2)#xL] : +- SubqueryAlias t -: +- CTERelationRef xxxx, true, [c1#x, c2#x] +: +- CTERelationRef xxxx, true, [c1#x, c2#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -677,10 +677,10 @@ Project [c1#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL] : : +- Project [c1#x, c2#x] : : +- Filter (outer(c1#x) = c1#x) : : +- SubqueryAlias t3 -: : +- CTERelationRef xxxx, true, [c1#x, c2#x] +: : +- CTERelationRef xxxx, true, [c1#x, c2#x], false : +- Aggregate [sum(c2#x) AS sum(c2)#xL] : +- SubqueryAlias t4 -: +- CTERelationRef xxxx, true, [c1#x, c2#x] +: +- CTERelationRef xxxx, true, [c1#x, c2#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -713,10 +713,10 @@ Project [c1#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL] : +- Union false, false : :- Project [c1#x, c2#x] : : +- SubqueryAlias t -: : +- CTERelationRef xxxx, true, [c1#x, c2#x] +: : +- CTERelationRef xxxx, true, [c1#x, c2#x], false : +- Project [c2#x, c1#x] : +- SubqueryAlias t -: +- CTERelationRef xxxx, 
true, [c1#x, c2#x] +: +- CTERelationRef xxxx, true, [c1#x, c2#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -756,9 +756,9 @@ WithCTE : : +- Aggregate [sum(c2#x) AS sum(c2)#xL] : : +- Filter (c1#x = outer(c1#x)) : : +- SubqueryAlias t - : : +- CTERelationRef xxxx, true, [c1#x, c2#x] + : : +- CTERelationRef xxxx, true, [c1#x, c2#x], false : +- SubqueryAlias v - : +- CTERelationRef xxxx, true, [c1#x, c2#x] + : +- CTERelationRef xxxx, true, [c1#x, c2#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -779,7 +779,7 @@ WithCTE : +- Project [a#x] : +- Filter (a#x = outer(c1#x)) : +- SubqueryAlias t - : +- CTERelationRef xxxx, true, [a#x] + : +- CTERelationRef xxxx, true, [a#x], false +- SubqueryAlias t1 +- View (`t1`, [c1#x,c2#x]) +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] @@ -1027,7 +1027,7 @@ WithCTE : +- Aggregate [sum(1) AS sum(1)#xL] : +- Filter ((a#x = cast(outer(col#x) as int)) OR (upper(cast(outer(col#x) as string)) = Y)) : +- SubqueryAlias T - : +- CTERelationRef xxxx, true, [a#x] + : +- CTERelationRef xxxx, true, [a#x], false +- SubqueryAlias foo +- Project [null AS col#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/transform.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/transform.sql.out index cda76f716a8a8..aa595c551f792 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/transform.sql.out @@ -888,10 +888,10 @@ WithCTE +- Join Inner, (b#x = b#x) :- SubqueryAlias t1 : +- SubqueryAlias temp - : +- CTERelationRef xxxx, true, [b#x] + : +- CTERelationRef xxxx, true, [b#x], false +- SubqueryAlias t2 +- SubqueryAlias temp - +- CTERelationRef xxxx, true, [b#x] + +- CTERelationRef xxxx, true, [b#x], false -- !query @@ -1035,3 +1035,14 @@ ScriptTransformation cat, [a#x, b#x], ScriptInputOutputSchema(List(),List(),None +- Project [a#x, b#x] +- SubqueryAlias complex_trans +- LocalRelation [a#x, b#x] + + +-- !query +SELECT TRANSFORM (a, b) + USING 'cat' AS (a CHAR(10), b VARCHAR(10)) +FROM VALUES('apache', 'spark') t(a, b) +-- !query analysis +ScriptTransformation cat, [a#x, b#x], ScriptInputOutputSchema(List(),List(),None,None,List(),List(),None,None,false) ++- Project [a#x, b#x] + +- SubqueryAlias t + +- LocalRelation [a#x, b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out index bbc07c22805a6..ceda149c48434 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out @@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x] +- OneRowRelation +-- !query +SELECT try_add(2147483647, decimal(1)) +-- !query analysis +Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x] ++- OneRowRelation + + +-- !query +SELECT try_add(2147483647, "1") +-- !query analysis +Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x] ++- OneRowRelation + + -- !query SELECT try_add(-2147483648, -1) -- !query analysis @@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project 
[try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + +-- !query +SELECT try_divide(1, "0") +-- !query analysis +Project [try_divide(1, 0) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis @@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x] +- OneRowRelation +-- !query +SELECT try_subtract(2147483647, decimal(-1)) +-- !query analysis +Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x] ++- OneRowRelation + + +-- !query +SELECT try_subtract(2147483647, "-1") +-- !query analysis +Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x] ++- OneRowRelation + + -- !query SELECT try_subtract(-2147483648, 1) -- !query analysis @@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x] +- OneRowRelation +-- !query +SELECT try_multiply(2147483647, decimal(-2)) +-- !query analysis +Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x] ++- OneRowRelation + + +-- !query +SELECT try_multiply(2147483647, "-2") +-- !query analysis +Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x] ++- OneRowRelation + + -- !query SELECT try_multiply(-2147483648, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/using-join.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/using-join.sql.out index 0fe7254d7348c..97410d3cdd369 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/using-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/using-join.sql.out @@ -833,6 +833,6 @@ WithCTE +- Project [coalesce(key#x, key#x) AS key#x, key#x, key#x, key#x] +- Join FullOuter, (key#x = key#x) :- SubqueryAlias t1 - : +- CTERelationRef xxxx, true, [key#x] + : +- CTERelationRef xxxx, true, [key#x], false +- SubqueryAlias t2 - +- CTERelationRef xxxx, true, [key#x] + +- CTERelationRef xxxx, true, [key#x], false diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 48edc6b474254..865dc8bac4ea5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -141,6 +141,7 @@ select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4); select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT)); select array_insert(array(2, 3, NULL, 4), 5, 5); select array_insert(array(2, 3, NULL, 4), -5, 1); +select array_insert(array(1), 2, cast(2 as tinyint)); set spark.sql.legacy.negativeIndexInArrayInsert=true; select array_insert(array(1, 3, 4), -2, 2); @@ -176,3 +177,7 @@ select array_prepend(CAST(null AS ARRAY), CAST(null as String)); select array_prepend(array(), 1); select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- SPARK-45599: Confirm 0.0, -0.0, and NaN are handled appropriately. 
+select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))); +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))); diff --git a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql index 8117dec53f4ab..be038e1083cd8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql @@ -49,6 +49,8 @@ desc formatted char_tbl1; create table char_part(c1 char(5), c2 char(2), v1 varchar(6), v2 varchar(2)) using parquet partitioned by (v2, c2); desc formatted char_part; +alter table char_part change column c1 comment 'char comment'; +alter table char_part change column v1 comment 'varchar comment'; alter table char_part add partition (v2='ke', c2='nt') location 'loc1'; desc formatted char_part; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c35cdb0de2719..ce1b422de3196 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -264,3 +264,18 @@ FROM ( GROUP BY b ) t3 GROUP BY c; + +-- SPARK-45599: Check that "weird" doubles group and sort as desired. +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 7925a21de04cd..37081de012e98 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -11,6 +11,8 @@ create or replace temporary view nested as values -- Only allow lambda's in higher order functions. select upper(x -> x) as v; +-- Also test functions registered with `ExpressionBuilder`. 
+select ceil(x -> x) as v; -- Identity transform an array select transform(zs, z -> z) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql index a1bd500455de9..e85fdf7b5da3d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql @@ -36,7 +36,7 @@ DROP SCHEMA s; -- Function reference SELECT IDENTIFIER('COAL' || 'ESCE')(NULL, 1); -SELECT IDENTIFIER('abs')(-1); +SELECT IDENTIFIER('abs')(c1) FROM VALUES(-1) AS T(c1); SELECT * FROM IDENTIFIER('ra' || 'nge')(0, 1); -- Table DDL @@ -107,9 +107,10 @@ SELECT IDENTIFIER('') FROM VALUES(1) AS T(``); VALUES(IDENTIFIER(CAST(NULL AS STRING))); VALUES(IDENTIFIER(1)); VALUES(IDENTIFIER(SUBSTR('HELLO', 1, RAND() + 1))); +SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1); -CREATE TABLE IDENTIFIER(1)(c1 INT); -CREATE TABLE IDENTIFIER('a.b.c')(c1 INT); +CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv; +CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv; CREATE VIEW IDENTIFIER('a.b.c')(c1) AS VALUES(1); DROP TABLE IDENTIFIER('a.b.c'); DROP VIEW IDENTIFIER('a.b.c'); @@ -121,6 +122,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg'); CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1); +-- SPARK-48273: Aggregation operation in statements using identifier clause for table name +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +create table identifier('t2') using csv as (select my_col from (values (1), (2), (1) as (my_col)) group by 1); +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1; +drop view v1; +drop table t1; +drop table t2; + -- Not supported SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1); SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1')); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 6867248f5765d..8f65dc77c960a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -60,3 +60,9 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- select * from values (try_add(5, 0)); select * from values (try_divide(5, 0)); select * from values (10 + try_divide(5, 0)); + +-- now() should be kept as tempResolved inline expression. +select count(distinct ct) from values now(), now(), now() as data(ct); + +-- current_timestamp() should be kept as tempResolved inline expression. +select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct); diff --git a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql index c7cb5bf1117a7..df286d2a9b0a9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql @@ -1,6 +1,6 @@ -- Test data. 
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x); -- SPARK-37613: Support ANSI Aggregate Function: regr_count diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 9f0eefc16a8cd..e1e4a370bffdc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -118,3 +118,6 @@ select +X'1'; select -date '1999-01-01'; select -timestamp '1999-01-01'; select -x'2379ACFe'; + +-- normalize -0 and -0.0 +select -0, -0.0; diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 922a1d8177780..8570496d439e6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -415,4 +415,8 @@ FROM ( ORDER BY a ) map_output SELECT TRANSFORM(a, b) - USING 'cat' AS (a, b); \ No newline at end of file + USING 'cat' AS (a, b); + +SELECT TRANSFORM (a, b) + USING 'cat' AS (a CHAR(10), b VARCHAR(10)) +FROM VALUES('apache', 'spark') t(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql index 55907b6701e50..943865b68d39e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql @@ -1,6 +1,8 @@ -- Numeric + Numeric SELECT try_add(1, 1); SELECT try_add(2147483647, 1); +SELECT try_add(2147483647, decimal(1)); +SELECT try_add(2147483647, "1"); SELECT try_add(-2147483648, -1); SELECT try_add(9223372036854775807L, 1); SELECT try_add(-9223372036854775808L, -1); @@ -38,6 +40,8 @@ SELECT try_divide(0, 0); SELECT try_divide(1, (2147483647 + 1)); SELECT try_divide(1L, (9223372036854775807L + 1L)); SELECT try_divide(1, 1.0 / 0.0); +SELECT try_divide(1, decimal(0)); +SELECT try_divide(1, "0"); -- Interval / Numeric SELECT try_divide(interval 2 year, 2); @@ -50,6 +54,8 @@ SELECT try_divide(interval 106751991 day, 0.5); -- Numeric - Numeric SELECT try_subtract(1, 1); SELECT try_subtract(2147483647, -1); +SELECT try_subtract(2147483647, decimal(-1)); +SELECT try_subtract(2147483647, "-1"); SELECT try_subtract(-2147483648, 1); SELECT try_subtract(9223372036854775807L, -1); SELECT try_subtract(-9223372036854775808L, 1); @@ -66,6 +72,8 @@ SELECT try_subtract(interval 106751991 day, interval -3 day); -- Numeric * Numeric SELECT try_multiply(2, 3); SELECT try_multiply(2147483647, -2); +SELECT try_multiply(2147483647, decimal(-2)); +SELECT try_multiply(2147483647, "-2"); SELECT try_multiply(-2147483648, 2); SELECT try_multiply(9223372036854775807L, 2); SELECT try_multiply(-9223372036854775808L, -2); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 03be0f9d84b1b..6a07d659e39b5 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -659,6 +659,14 @@ struct> [1,2,3,null,4] +-- !query +select array_insert(array(1), 2, cast(2 as tinyint)) +-- !query schema +struct> +-- !query output +[1,2] + + -- !query set spark.sql.legacy.negativeIndexInArrayInsert=true -- 
!query schema @@ -899,3 +907,19 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out index 699c916fd8fdb..9593291fae21d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out @@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000" + "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000" + "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000" + "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out index e479b49463e74..dceb370c83884 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out @@ -32,6 +32,27 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 9eb4a4766df89..b0d128e967a6d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -2355,7 +2355,7 @@ org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "_LEGACY_ERROR_TEMP_0063", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval year to month: -\t2-2\t" + "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index 85bcc2713ff5c..452580e4f3c34 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -770,3 +770,11 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query schema +struct<0:int,0.0:decimal(1,1)> +-- !query output +0 0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out index 414198b19645d..bb630243ee1ae 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out @@ -15,6 +15,22 @@ struct NULL +-- !query +SELECT try_add(2147483647, decimal(1)) +-- !query schema +struct +-- !query output +2147483648 + + +-- !query +SELECT try_add(2147483647, "1") +-- !query schema +struct +-- !query output +2147483648 + + -- !query SELECT try_add(-2147483648, -1) -- !query schema @@ -341,6 +357,22 @@ org.apache.spark.SparkArithmeticException } +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_divide(1, "0") +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema @@ -405,6 +437,22 @@ struct NULL +-- !query +SELECT try_subtract(2147483647, decimal(-1)) +-- !query schema +struct +-- !query output +2147483648 + + +-- !query +SELECT try_subtract(2147483647, "-1") +-- !query schema +struct +-- !query output +2147483648 + + 
-- !query SELECT try_subtract(-2147483648, 1) -- !query schema @@ -547,6 +595,22 @@ struct NULL +-- !query +SELECT try_multiply(2147483647, decimal(-2)) +-- !query schema +struct +-- !query output +-4294967294 + + +-- !query +SELECT try_multiply(2147483647, "-2") +-- !query schema +struct +-- !query output +-4294967294 + + -- !query SELECT try_multiply(-2147483648, 2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 9dbf4fbebc20b..d33fc62f0d9a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -540,6 +540,14 @@ struct> [1,2,3,null,4] +-- !query +select array_insert(array(1), 2, cast(2 as tinyint)) +-- !query schema +struct> +-- !query output +[1,2] + + -- !query set spark.sql.legacy.negativeIndexInArrayInsert=true -- !query schema @@ -780,3 +788,19 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] + + +-- !query +select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) +-- !query schema +struct> +-- !query output +[0.0,NaN] diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index 888e8a9428910..dd8bdc698ea7f 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -543,6 +543,22 @@ Location [not included in comparison]/{warehouse_dir}/char_part Partition Provider Catalog +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query schema +struct<> +-- !query output + + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query schema +struct<> +-- !query output + + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query schema @@ -556,8 +572,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -599,8 +615,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -634,8 +650,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -669,8 +685,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 3c2677c936f9c..54fa9ca418cc1 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1081,7 +1081,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 
'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index f54c6c5e44f2e..20314b5f9b93a 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1023,7 +1023,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index db79646fe435a..548917ef79b2d 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1121,3 +1121,25 @@ struct -- !query output 0 2 + + +-- !query +SELECT col1, count(*) AS cnt +FROM VALUES + (0.0), + (-0.0), + (double('NaN')), + (double('NaN')), + (double('Infinity')), + (double('Infinity')), + (-double('Infinity')), + (-double('Infinity')) +GROUP BY col1 +ORDER BY col1 +-- !query schema +struct +-- !query output +-Infinity 2 +0.0 2 +Infinity 2 +NaN 2 diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e479b49463e74..dceb370c83884 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -32,6 +32,27 @@ org.apache.spark.sql.AnalysisException } +-- !query +select ceil(x -> x) as v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + "messageParameters" : { + "class" : "org.apache.spark.sql.catalyst.expressions.CeilExpressionBuilder$" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "ceil(x -> x)" + } ] +} + + -- !query select transform(zs, z -> z) as v from nested -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index 8eabb74da97ba..62f43152c48d9 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -205,9 +205,9 @@ struct -- !query -SELECT IDENTIFIER('abs')(-1) +SELECT IDENTIFIER('abs')(c1) FROM VALUES(-1) AS T(c1) -- !query schema -struct +struct -- !query output 1 @@ -771,7 +771,30 @@ org.apache.spark.sql.AnalysisException -- !query -CREATE TABLE IDENTIFIER(1)(c1 INT) +SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_ROUTINE", + "sqlState" : "42883", + "messageParameters" : { + 
"routineName" : "`IDENTIFIER`", + "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "`IDENTIFIER`('abs')" + } ] +} + + +-- !query +CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv -- !query schema struct<> -- !query output @@ -795,7 +818,7 @@ org.apache.spark.sql.AnalysisException -- !query -CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) +CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv -- !query schema struct<> -- !query output @@ -987,6 +1010,62 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table identifier('t2') using csv as (select my_col from (values (1), (2), (1) as (my_col)) group by 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop view v1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t2 +-- !query schema +struct<> +-- !query output + + + -- !query SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index 709d7ab73f6c4..b6c90b95c1d34 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -266,3 +266,19 @@ select * from values (10 + try_divide(5, 0)) struct -- !query output NULL + + +-- !query +select count(distinct ct) from values now(), now(), now() as data(ct) +-- !query schema +struct +-- !query output +1 + + +-- !query +select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct) +-- !query schema +struct +-- !query output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index fe15ade941785..faba4abfdbe7d 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -2168,7 +2168,7 @@ org.apache.spark.sql.catalyst.parser.ParseException { "errorClass" : "_LEGACY_ERROR_TEMP_0063", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval year to month: -\t2-2\t" + "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out index 1379713a9fb0d..e511ea75aae5a 100644 --- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out @@ -1,7 +1,7 @@ -- Automatically generated by SQLQueryTestSuite -- !query CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x) -- !query schema struct<> @@ -31,7 +31,7 @@ SELECT k, count(*), regr_count(y, x) FROM testRegression GROUP BY k struct -- !query output 1 1 0 -2 4 3 +2 5 3 -- !query @@ -40,7 +40,7 @@ SELECT k, count(*) FILTER (WHERE x IS NOT NULL), regr_count(y, x) FROM testRegre struct -- !query output 1 0 0 -2 3 3 +2 4 3 -- !query @@ -99,7 +99,7 @@ SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression G struct -- !query output 1 NULL 10.0 NULL NULL -2 22.666666666666668 21.25 22.666666666666668 20.0 +2 27.0 21.25 22.666666666666668 20.0 -- !query @@ -116,7 +116,7 @@ SELECT regr_sxx(y, x) FROM testRegression -- !query schema struct -- !query output -288.66666666666663 +288.6666666666667 -- !query @@ -124,7 +124,7 @@ SELECT regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL -- !query schema struct -- !query output -288.66666666666663 +288.6666666666667 -- !query @@ -133,7 +133,7 @@ SELECT k, regr_sxx(y, x) FROM testRegression GROUP BY k struct -- !query output 1 NULL -2 288.66666666666663 +2 288.6666666666667 -- !query @@ -141,7 +141,7 @@ SELECT k, regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NU -- !query schema struct -- !query output -2 288.66666666666663 +2 288.6666666666667 -- !query @@ -215,7 +215,7 @@ SELECT regr_slope(y, x) FROM testRegression -- !query schema struct -- !query output -0.8314087759815244 +0.8314087759815242 -- !query @@ -223,7 +223,7 @@ SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NUL -- !query schema struct -- !query output -0.8314087759815244 +0.8314087759815242 -- !query @@ -232,7 +232,7 @@ SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k struct -- !query output 1 NULL -2 0.8314087759815244 +2 0.8314087759815242 -- !query @@ -240,7 +240,7 @@ SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT -- !query schema struct -- !query output -2 0.8314087759815244 +2 0.8314087759815242 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 85bcc2713ff5c..452580e4f3c34 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -770,3 +770,11 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "-x'2379ACFe'" } ] } + + +-- !query +select -0, -0.0 +-- !query schema +struct<0:int,0.0:decimal(1,1)> +-- !query output +0 0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out index dcb96b9d2dce6..e1f4e3068b458 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out @@ -78,7 +78,7 @@ CREATE TABLE spark_catalog.default.tbl ( b STRING, c INT) USING parquet -LOCATION 'file:///path/to/table' +LOCATION 'file:/path/to/table' -- !query @@ -108,7 +108,7 @@ CREATE TABLE spark_catalog.default.tbl ( b STRING, c INT) USING parquet -LOCATION 'file:///path/to/table' 
+LOCATION 'file:/path/to/table' -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index ab726b93c07c8..7975392fd0147 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -837,3 +837,13 @@ struct 3 3 3 3 3 3 + + +-- !query +SELECT TRANSFORM (a, b) + USING 'cat' AS (a CHAR(10), b VARCHAR(10)) +FROM VALUES('apache', 'spark') t(a, b) +-- !query schema +struct +-- !query output +apache spark diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out index c706a26078926..76f1d89b20927 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out @@ -15,6 +15,22 @@ struct NULL +-- !query +SELECT try_add(2147483647, decimal(1)) +-- !query schema +struct +-- !query output +2147483648 + + +-- !query +SELECT try_add(2147483647, "1") +-- !query schema +struct +-- !query output +2.147483648E9 + + -- !query SELECT try_add(-2147483648, -1) -- !query schema @@ -249,6 +265,22 @@ struct NULL +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT try_divide(1, "0") +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema @@ -313,6 +345,22 @@ struct NULL +-- !query +SELECT try_subtract(2147483647, decimal(-1)) +-- !query schema +struct +-- !query output +2147483648 + + +-- !query +SELECT try_subtract(2147483647, "-1") +-- !query schema +struct +-- !query output +2.147483648E9 + + -- !query SELECT try_subtract(-2147483648, 1) -- !query schema @@ -409,6 +457,22 @@ struct NULL +-- !query +SELECT try_multiply(2147483647, decimal(-2)) +-- !query schema +struct +-- !query output +-4294967294 + + +-- !query +SELECT try_multiply(2147483647, "-2") +-- !query schema +struct +-- !query output +-4.294967294E9 + + -- !query SELECT try_multiply(-2147483648, 2) -- !query schema diff --git a/sql/core/src/test/resources/test-data/char.csv b/sql/core/src/test/resources/test-data/char.csv new file mode 100644 index 0000000000000..d2be68a15fc12 --- /dev/null +++ b/sql/core/src/test/resources/test-data/char.csv @@ -0,0 +1,4 @@ +color,name +pink,Bob +blue,Mike +grey,Tom diff --git a/sql/core/src/test/resources/test-data/test-archive.har/_index b/sql/core/src/test/resources/test-data/test-archive.har/_index new file mode 100644 index 0000000000000..b7ae3ef9c5a4c --- /dev/null +++ b/sql/core/src/test/resources/test-data/test-archive.har/_index @@ -0,0 +1,2 @@ +%2F dir 1707380620211+493+tigrulya+hadoop 0 0 test.csv +%2Ftest.csv file part-0 0 6 1707380620197+420+tigrulya+hadoop diff --git a/sql/core/src/test/resources/test-data/test-archive.har/_masterindex b/sql/core/src/test/resources/test-data/test-archive.har/_masterindex new file mode 100644 index 0000000000000..4192a9597299b --- /dev/null +++ b/sql/core/src/test/resources/test-data/test-archive.har/_masterindex @@ -0,0 +1,2 @@ +3 +0 1948547033 0 119 diff --git a/sql/core/src/test/resources/test-data/test-archive.har/part-0 b/sql/core/src/test/resources/test-data/test-archive.har/part-0 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/core/src/test/resources/test-data/test-archive.har/part-0 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 5f6c44792658a..73f5b742715eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -678,6 +678,27 @@ abstract class CTEInlineSuiteBase }.isDefined, "CTE columns should not be pruned.") } } + + test("SPARK-49816: should only update out-going-ref-count for referenced outer CTE relation") { + withView("v") { + sql( + """ + |WITH + |t1 AS (SELECT 1 col), + |t2 AS (SELECT * FROM t1) + |SELECT * FROM t2 + |""".stripMargin).createTempView("v") + // r1 is un-referenced, but it should not decrease the ref count of t2 inside view v. + val df = sql( + """ + |WITH + |r1 AS (SELECT * FROM v), + |r2 AS (SELECT * FROM v) + |SELECT * FROM r2 + |""".stripMargin) + checkAnswer(df, Row(1)) + } + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e4a67347f5b1..9815cb816c994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.CleanerListener import org.apache.spark.executor.DataReadMethod._ import org.apache.spark.executor.DataReadMethod.DataReadMethod -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} @@ -1623,23 +1624,44 @@ class CachedTableSuite extends QueryTest with SQLTestUtils SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - withTempView("t1", "t2", "t3") { - withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") { - sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM values(1) as t(c)") - assert(spark.table("t1").rdd.partitions.length == 2) + var finalPlan = "" + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, physicalPlanDesc, sparkPlanInfo) => + if (sparkPlanInfo.simpleString.startsWith( + "AdaptiveSparkPlan isFinalPlan=true")) { + finalPlan = physicalPlanDesc + } + case _ => // ignore other events + } } + } - withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") { - assert(spark.table("t1").rdd.partitions.length == 2) - sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM values(2) as t(c)") - 
assert(spark.table("t2").rdd.partitions.length == 1) - } + withTempView("t0", "t1", "t2") { + try { + spark.range(10).write.saveAsTable("t0") + spark.sparkContext.listenerBus.waitUntilEmpty() + spark.sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") { - assert(spark.table("t1").rdd.partitions.length == 2) - assert(spark.table("t2").rdd.partitions.length == 1) - sql("CACHE TABLE t3 as SELECT /*+ REPARTITION */ * FROM values(3) as t(c)") - assert(spark.table("t3").rdd.partitions.length == 2) + withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") { + sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM (" + + "SELECT distinct (id+1) FROM t0)") + assert(spark.table("t1").rdd.partitions.length == 2) + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(finalPlan.nonEmpty && !finalPlan.contains("coalesced")) + } + + finalPlan = "" // reset finalPlan + withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") { + sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM (" + + "SELECT distinct (id-1) FROM t0)") + assert(spark.table("t2").rdd.partitions.length == 2) + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(finalPlan.nonEmpty && finalPlan.contains("coalesced")) + } + } finally { + spark.sparkContext.removeSparkListener(listener) } } } @@ -1688,4 +1710,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils } } } + + test("SPARK-47633: Cache hit for lateral join with join condition") { + withTempView("t", "q1") { + sql("create or replace temp view t(c1, c2) as values (0, 1), (1, 2)") + val query = """select * + |from t + |join lateral ( + | select c1 as a, c2 as b + | from t) + |on c1 = a; + |""".stripMargin + sql(s"cache table q1 as $query") + val df = sql(query) + checkAnswer(df, + Row(0, 1, 0, 1) :: Row(1, 2, 1, 2) :: Nil) + assert(getNumInMemoryRelations(df) == 1) + } + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 4a7632486c046..c5d34a33a0abe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -1064,6 +1064,54 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } } } + + test("SPARK-48498: always do char padding in predicates") { + import testImplicits._ + withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") { + withTempPath { dir => + withTable("t1", "t2") { + Seq( + "12" -> "12", + "12" -> "12 ", + "12 " -> "12", + "12 " -> "12 " + ).toDF("c1", "c2").write.format(format).save(dir.toString) + + sql(s"CREATE TABLE t1 (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'") + // Comparing CHAR column with STRING column directly compares the stored value. + checkAnswer( + sql("SELECT c1 = c2 FROM t1"), + Seq(Row(true), Row(false), Row(false), Row(true)) + ) + checkAnswer( + sql("SELECT c1 IN (c2) FROM t1"), + Seq(Row(true), Row(false), Row(false), Row(true)) + ) + // No matter the CHAR type value is padded or not in the storage, we should always pad it + // before comparison with STRING literals. 
+ checkAnswer( + sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t1 WHERE c2 = '12'"), + Seq(Row(true, true, true), Row(true, true, true)) + ) + checkAnswer( + sql("SELECT c1 IN ('12'), c1 IN ('12 '), c1 IN ('12 ') FROM t1 WHERE c2 = '12'"), + Seq(Row(true, true, true), Row(true, true, true)) + ) + + sql(s"CREATE TABLE t2 (c1 CHAR(3), c2 CHAR(5)) USING $format LOCATION '$dir'") + // Comparing CHAR column with CHAR column compares the padded values. + checkAnswer( + sql("SELECT c1 = c2, c2 = c1 FROM t2"), + Seq(Row(true, true), Row(true, true), Row(true, true), Row(true, true)) + ) + checkAnswer( + sql("SELECT c1 IN (c2), c2 IN (c1) FROM t2"), + Seq(Row(true, true), Row(true, true), Row(true, true), Row(true, true)) + ) + } + } + } + } } class DSV2CharVarcharTestSuite extends CharVarcharTestSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d78771a8f19bc..5a8681aed973a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1068,6 +1069,39 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-45599: Neither 0.0 nor -0.0 should be dropped when computing percentile") { + // To reproduce the bug described in SPARK-45599, we need exactly these rows in roughly + // this order in a DataFrame with exactly 1 partition. + // scalastyle:off line.size.limit + // See: https://issues.apache.org/jira/browse/SPARK-45599?focusedCommentId=17806954&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17806954 + // scalastyle:on line.size.limit + val spark45599Repro: DataFrame = Seq( + 0.0, + 2.0, + 153.0, + 168.0, + 3252411229536261.0, + 7.205759403792794e+16, + 1.7976931348623157e+308, + 0.25, + Double.NaN, + Double.NaN, + -0.0, + -128.0, + Double.NaN, + Double.NaN + ).toDF("val").coalesce(1) + + checkAnswer( + spark45599Repro.agg( + percentile(col("val"), lit(0.1)) + ), + // With the buggy implementation of OpenHashSet, this returns `0.050000000000000044` + // instead of `-0.0`. 
+ List(Row(-0.0)) + ) + } + test("any_value") { checkAnswer( courseSales.groupBy("course").agg( @@ -2102,6 +2136,152 @@ class DataFrameAggregateSuite extends QueryTest Seq(Row(1)) ) } + + test("SPARK-46779: Group by subquery with a cached relation") { + withTempView("data") { + sql( + """create or replace temp view data(c1, c2) as values + |(1, 2), + |(1, 3), + |(3, 7)""".stripMargin) + sql("cache table data") + val df = sql( + """select c1, (select count(*) from data d1 where d1.c1 = d2.c1), count(c2) + |from data d2 group by all""".stripMargin) + checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil) + } + } + + test("aggregating with various distinct expressions") { + abstract class AggregateTestCaseBase( + val query: String, + val resultSeq: Seq[Seq[Row]], + val hasExpandNodeInPlan: Boolean) + case class AggregateTestCase( + override val query: String, + override val resultSeq: Seq[Seq[Row]], + override val hasExpandNodeInPlan: Boolean) + extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan) + case class AggregateTestCaseDefault( + override val query: String) + extends AggregateTestCaseBase( + query, + Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = true) + + val t = "t" + val testCases: Seq[AggregateTestCaseBase] = Seq( + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1 + 2) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t" + ), + AggregateTestCase( + s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT 1, "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT current_date()) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t""" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col", + Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0", + Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))), + hasExpandNodeInPlan = true + ) + ) + withTable(t) { + sql(s"create table 
$t(col int) using parquet") + Seq(0, 1, 2).foreach(columnValue => { + if (columnValue != 0) { + sql(s"insert into $t(col) values($columnValue)") + } + testCases.foreach(testCase => { + val query = sql(testCase.query) + checkAnswer(query, testCase.resultSeq(columnValue)) + val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst { + case _: Expand => true + }.nonEmpty + assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan) + }) + }) + } + } + + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { + val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c") + withTempView("v1") { + data.createOrReplaceTempView("v1") + val df = + sql("""SELECT + | ROUND(SUM(b), 6) AS sum1, + | COUNT(DISTINCT a) AS count1, + | COUNT(DISTINCT c) AS count2 + |FROM ( + | SELECT + | 6 AS gb, + | * + | FROM v1 + |) + |GROUP BY a, gb + |""".stripMargin) + checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil) + } + } } case class B(c: Option[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 4f25642906628..95c5c5590e504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -31,6 +31,15 @@ import org.apache.spark.sql.types.ArrayType class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { import testImplicits._ + test("ArrayTransform with scan input") { + withTempPath { f => + spark.sql("select array(array(1, null, 3), array(4, 5, null), array(null, 8, 9)) as a") + .write.parquet(f.getAbsolutePath) + val df = spark.read.parquet(f.getAbsolutePath).selectExpr("transform(a, (x, i) -> x)") + checkAnswer(df, Row(Seq(Seq(1, null, 3), Seq(4, 5, null), Seq(null, 8, 9)))) + } + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 1ac1dda374fa7..6c1ca94a03079 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession } } } + + test("SPARK-49836 using window fn with window as parameter should preserve parent operator") { + withTempView("clicks") { + val df = Seq( + // small window: [00:00, 01:00), user1, 2 + ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"), + // small window: [01:00, 02:00), user2, 2 + ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"), + // small window: [03:00, 04:00), user1, 1 + ("2024-09-30 00:03:30", "user1"), + // small window: [11:00, 12:00), user1, 3 + ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"), + ("2024-09-30 00:11:45", "user1") + ).toDF("eventTime", "userId") + + // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 / + // (12:00, 12:05), user1, 3 + + df.createOrReplaceTempView("clicks") + + val aggregatedData = spark.sql( + """ + |SELECT + | userId, + | avg(cpu_large.numClicks) AS clicksPerSession + |FROM + |( + | SELECT + | session_window(small_window, '5 minutes') AS session, + | userId, + | 
sum(numClicks) AS numClicks + | FROM + | ( + | SELECT + | window(eventTime, '1 minute') AS small_window, + | userId, + | count(*) AS numClicks + | FROM clicks + | GROUP BY window, userId + | ) cpu_small + | GROUP BY session_window, userId + |) cpu_large + |GROUP BY userId + |""".stripMargin) + + checkAnswer( + aggregatedData, + Seq(Row("user1", 3), Row("user2", 2)) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2eba9f1810982..7ee18df375616 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics} @@ -368,20 +368,6 @@ class DataFrameSuite extends QueryTest Row("a", Seq("a"), 1) :: Nil) } - test("more than one generator in SELECT clause") { - val df = Seq((Array("a"), 1)).toDF("a", "b") - - checkError( - exception = intercept[AnalysisException] { - df.select(explode($"a").as("a"), explode($"a").as("b")) - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(a)\", \"explode(a)\"")) - } - test("sort after generate with join=true") { val df = Seq((Array("a"), 1)).toDF("a", "b") @@ -2258,6 +2244,20 @@ class DataFrameSuite extends QueryTest assert(newConstraints === newExpectedConstraints) } + test("SPARK-46794: exclude subqueries from LogicalRDD constraints") { + withTempDir { checkpointDir => + val subquery = + new Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) + val df = spark.range(1000).filter($"id" === subquery) + assert(df.logicalPlan.constraints.exists(_.exists(_.isInstanceOf[ScalarSubquery]))) + + spark.sparkContext.setCheckpointDir(checkpointDir.getAbsolutePath) + val checkpointedDf = df.checkpoint() + assert(!checkpointedDf.logicalPlan.constraints + .exists(_.exists(_.isInstanceOf[ScalarSubquery]))) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) @@ -3636,6 +3636,15 @@ class DataFrameSuite extends QueryTest assert(row.getInt(0).toString == row.getString(2)) assert(row.getInt(0).toString == row.getString(3)) } + + val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value")))) + val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback + df.select(v3, v3, v4, v4).collect().foreach { row => + assert(row.getMap(0).toString() == row.getMap(1).toString()) + val expectedString = s"keys: [key], values: [${row.getMap(0).get("key").get}]" + assert(row.getString(2) == 
s"""\"$expectedString\"""") + assert(row.getString(3) == s"""\"$expectedString\"""") + } } test("SPARK-41219: IntegralDivide use decimal(1, 0) to represent 0") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6ee173bc6af67..c52d428cd5dd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.sql.Timestamp import java.time.LocalDateTime import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -714,4 +715,56 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { ) } } + + test("SPARK-49836 using window fn with window as parameter should preserve parent operator") { + withTempView("clicks") { + val df = Seq( + // small window: [00:00, 01:00), user1, 2 + ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"), + // small window: [01:00, 02:00), user2, 2 + ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"), + // small window: [07:00, 08:00), user1, 1 + ("2024-09-30 00:07:00", "user1"), + // small window: [11:00, 12:00), user1, 3 + ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"), + ("2024-09-30 00:11:45", "user1") + ).toDF("eventTime", "userId") + + // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3 + + df.createOrReplaceTempView("clicks") + + val aggregatedData = spark.sql( + """ + |SELECT + | cpu_large.large_window.end AS timestamp, + | avg(cpu_large.numClicks) AS avgClicksPerUser + |FROM + |( + | SELECT + | window(small_window, '10 minutes') AS large_window, + | userId, + | sum(numClicks) AS numClicks + | FROM + | ( + | SELECT + | window(eventTime, '1 minute') AS small_window, + | userId, + | count(*) AS numClicks + | FROM clicks + | GROUP BY window, userId + | ) cpu_small + | GROUP BY window, userId + |) cpu_large + |GROUP BY timestamp + |""".stripMargin) + + checkAnswer( + aggregatedData, + Seq( + Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5), + Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3)) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index a57e927ba8427..47a311c71d55d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -819,6 +819,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest lead($"value", 1, null, true).over(window), lead($"value", 2, null, true).over(window), lead($"value", 3, null, true).over(window), + // offset > rowCount: SPARK-45430 + lead($"value", 100, null, true).over(window), lead(concat($"value", $"key"), 1, null, true).over(window), lag($"value", 1).over(window), lag($"value", 2).over(window), @@ -826,27 +828,29 @@ class DataFrameWindowFunctionsSuite extends QueryTest lag($"value", 1, null, true).over(window), lag($"value", 2, null, true).over(window), lag($"value", 3, null, true).over(window), + // abs(offset) > rowCount: SPARK-45430 + lag($"value", -100, null, true).over(window), lag(concat($"value", $"key"), 1, null, true).over(window)) .orderBy($"order"), Seq( - Row("a", 0, null, "x", null, null, "x", "y", "z", "xa", - null, null, null, null, 
null, null, null), - Row("a", 1, "x", null, null, "x", "y", "z", "v", "ya", - null, null, "x", null, null, null, null), - Row("b", 2, null, null, "y", null, "y", "z", "v", "ya", - "x", null, null, "x", null, null, "xa"), - Row("c", 3, null, "y", null, null, "y", "z", "v", "ya", - null, "x", null, "x", null, null, "xa"), - Row("a", 4, "y", null, "z", "y", "z", "v", null, "za", - null, null, "y", "x", null, null, "xa"), - Row("b", 5, null, "z", "v", null, "z", "v", null, "za", - "y", null, null, "y", "x", null, "ya"), - Row("a", 6, "z", "v", null, "z", "v", null, null, "va", - null, "y", "z", "y", "x", null, "ya"), - Row("a", 7, "v", null, null, "v", null, null, null, null, - "z", null, "v", "z", "y", "x", "za"), - Row("a", 8, null, null, null, null, null, null, null, null, - "v", "z", null, "v", "z", "y", "va"))) + Row("a", 0, null, "x", null, null, "x", "y", "z", null, "xa", + null, null, null, null, null, null, null, null), + Row("a", 1, "x", null, null, "x", "y", "z", "v", null, "ya", + null, null, "x", null, null, null, null, null), + Row("b", 2, null, null, "y", null, "y", "z", "v", null, "ya", + "x", null, null, "x", null, null, null, "xa"), + Row("c", 3, null, "y", null, null, "y", "z", "v", null, "ya", + null, "x", null, "x", null, null, null, "xa"), + Row("a", 4, "y", null, "z", "y", "z", "v", null, null, "za", + null, null, "y", "x", null, null, null, "xa"), + Row("b", 5, null, "z", "v", null, "z", "v", null, null, "za", + "y", null, null, "y", "x", null, null, "ya"), + Row("a", 6, "z", "v", null, "z", "v", null, null, null, "va", + null, "y", "z", "y", "x", null, null, "ya"), + Row("a", 7, "v", null, null, "v", null, null, null, null, null, + "z", null, "v", "z", "y", "x", null, "za"), + Row("a", 8, null, null, null, null, null, null, null, null, null, + "v", "z", null, "v", "z", "y", null, "va"))) } test("lag - Offset expression must be a literal") { @@ -1521,4 +1525,116 @@ class DataFrameWindowFunctionsSuite extends QueryTest assert(windows.size === 1) } } + + test("SPARK-45543: InferWindowGroupLimit causes bug " + + "if the other window functions haven't the same window frame as the rank-like functions") { + val df = Seq( + (1, "Dave", 1, 2020), + (2, "Dave", 1, 2021), + (3, "Dave", 2, 2022), + (4, "Dave", 3, 2023), + (5, "Dave", 3, 2024), + (6, "Mark", 2, 2022), + (7, "Mark", 3, 2023), + (8, "Mark", 3, 2024), + (9, "Amy", 6, 2021), + (10, "Amy", 5, 2022), + (11, "Amy", 6, 2023), + (12, "Amy", 7, 2024), + (13, "John", 7, 2024)).toDF("id", "name", "score", "year") + + val window = Window.partitionBy($"year").orderBy($"score".desc) + val window2 = window.rowsBetween(Window.unboundedPreceding, Window.currentRow) + val window3 = window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + Seq(-1, 100).foreach { threshold => + withSQLConf(SQLConf.WINDOW_GROUP_LIMIT_THRESHOLD.key -> threshold.toString) { + // The other window functions have the same window frame as the rank-like functions. 
+ // df2, df3 and df4 can apply InferWindowGroupLimit + val df2 = df + .withColumn("rn", row_number().over(window)) + .withColumn("all_scores", collect_list($"score").over(window2)) + .sort($"year") + + checkAnswer(df2.filter("rn=1"), Seq( + Row(1, "Dave", 1, 2020, 1, Array(1)), + Row(9, "Amy", 6, 2021, 1, Array(6)), + Row(10, "Amy", 5, 2022, 1, Array(5)), + Row(11, "Amy", 6, 2023, 1, Array(6)), + Row(12, "Amy", 7, 2024, 1, Array(7)) + )) + + val df3 = df + .withColumn("rank", rank().over(window)) + .withColumn("all_scores", collect_list($"score").over(window2)) + .sort($"year") + + checkAnswer(df3.filter("rank=2"), Seq( + Row(2, "Dave", 1, 2021, 2, Array(6, 1)), + Row(3, "Dave", 2, 2022, 2, Array(5, 2)), + Row(6, "Mark", 2, 2022, 2, Array(5, 2, 2)), + Row(4, "Dave", 3, 2023, 2, Array(6, 3)), + Row(7, "Mark", 3, 2023, 2, Array(6, 3, 3)) + )) + + val df4 = df + .withColumn("rank", dense_rank().over(window)) + .withColumn("all_scores", collect_list($"score").over(window2)) + .sort($"year") + + checkAnswer(df4.filter("rank=2"), Seq( + Row(2, "Dave", 1, 2021, 2, Array(6, 1)), + Row(3, "Dave", 2, 2022, 2, Array(5, 2)), + Row(6, "Mark", 2, 2022, 2, Array(5, 2, 2)), + Row(4, "Dave", 3, 2023, 2, Array(6, 3)), + Row(7, "Mark", 3, 2023, 2, Array(6, 3, 3)), + Row(5, "Dave", 3, 2024, 2, Array(7, 7, 3)), + Row(8, "Mark", 3, 2024, 2, Array(7, 7, 3, 3)) + )) + + // The other window functions haven't the same window frame as the rank-like functions. + // df5, df6 and df7 cannot apply InferWindowGroupLimit + val df5 = df + .withColumn("rn", row_number().over(window)) + .withColumn("all_scores", collect_list($"score").over(window3)) + .sort($"year") + + checkAnswer(df5.filter("rn=1"), Seq( + Row(1, "Dave", 1, 2020, 1, Array(1)), + Row(9, "Amy", 6, 2021, 1, Array(6, 1)), + Row(10, "Amy", 5, 2022, 1, Array(5, 2, 2)), + Row(11, "Amy", 6, 2023, 1, Array(6, 3, 3)), + Row(12, "Amy", 7, 2024, 1, Array(7, 7, 3, 3)) + )) + + val df6 = df + .withColumn("rank", rank().over(window)) + .withColumn("all_scores", collect_list($"score").over(window3)) + .sort($"year") + + checkAnswer(df6.filter("rank=2"), Seq( + Row(2, "Dave", 1, 2021, 2, Array(6, 1)), + Row(3, "Dave", 2, 2022, 2, Array(5, 2, 2)), + Row(6, "Mark", 2, 2022, 2, Array(5, 2, 2)), + Row(4, "Dave", 3, 2023, 2, Array(6, 3, 3)), + Row(7, "Mark", 3, 2023, 2, Array(6, 3, 3)) + )) + + val df7 = df + .withColumn("rank", dense_rank().over(window)) + .withColumn("all_scores", collect_list($"score").over(window3)) + .sort($"year") + + checkAnswer(df7.filter("rank=2"), Seq( + Row(2, "Dave", 1, 2021, 2, Array(6, 1)), + Row(3, "Dave", 2, 2022, 2, Array(5, 2, 2)), + Row(6, "Mark", 2, 2022, 2, Array(5, 2, 2)), + Row(4, "Dave", 3, 2023, 2, Array(6, 3, 3)), + Row(7, "Mark", 3, 2023, 2, Array(6, 3, 3)), + Row(5, "Dave", 3, 2024, 2, Array(7, 7, 3, 3)), + Row(8, "Mark", 3, 2024, 2, Array(7, 7, 3, 3)) + )) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c967540541a5c..f32b32ffc5a5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: 
String) @@ -269,6 +270,13 @@ class DatasetSuite extends QueryTest (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } + test("SPARK-45896: seq of option of seq") { + val ds = Seq(DataSeqOptSeq(Seq(Some(Seq(0))))).toDS() + checkDataset( + ds, + DataSeqOptSeq(Seq(Some(List(0))))) + } + test("select") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( @@ -2496,6 +2504,18 @@ class DatasetSuite extends QueryTest assert(result == expected) } + test("SPARK-47385: Tuple encoder with Option inputs") { + implicit val enc: Encoder[(SingleData, Option[SingleData])] = + Encoders.tuple(Encoders.product[SingleData], Encoders.product[Option[SingleData]]) + + val input = Seq( + (SingleData(1), Some(SingleData(1))), + (SingleData(2), None) + ) + val ds = spark.createDataFrame(input).as[(SingleData, Option[SingleData])] + checkDataset(ds, input: _*) + } + test("SPARK-43124: Show does not trigger job execution on CommandResults") { withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") { withTable("t1") { @@ -2535,6 +2555,38 @@ class DatasetSuite extends QueryTest checkDataset(ds.filter(f(col("_1"))), Tuple1(ValueClass(2))) } + + test("SPARK-45386: persist with StorageLevel.NONE should give correct count") { + val ds = Seq(1, 2).toDS().persist(StorageLevel.NONE) + assert(ds.count() == 2) + } + + test("SPARK-45592: Coaleasced shuffle read is not compatible with hash partitioning") { + withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "20", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000") { + val ee = spark.range(0, 1000, 1, 5).map(l => (l, l - 1)).toDF() + .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK) + ee.count() + + // `minNbrs1` will start with 20 partitions and without the fix would coalesce to ~10 + // partitions. + val minNbrs1 = ee + .groupBy("_2").agg(min(col("_1")).as("min_number")) + .select(col("_2") as "_1", col("min_number")) + .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK) + minNbrs1.count() + + // shuffle on `ee` will start with 2 partitions, smaller than `minNbrs1`'s partition num, + // and `EnsureRequirements` will change its partition num to `minNbrs1`'s partition num. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + val join = ee.join(minNbrs1, "_1") + assert(join.count() == 999) + } + } + } + } class DatasetLargeResultCollectingSuite extends QueryTest @@ -2609,6 +2661,8 @@ case class ClassNullableData(a: String, b: Integer) case class NestedStruct(f: ClassData) case class DeepNestedStruct(f: NestedStruct) +case class DataSeqOptSeq(a: Seq[Option[Seq[Int]]]) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. 
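The SPARK-47385 test added in the DatasetSuite hunk above builds a tuple encoder where one component is an Option of a case class and asserts that such tuples round-trip through the encoder unchanged. As a rough standalone sketch of that usage pattern — the Person case class, the local-mode session, and the sample data are illustrative assumptions only, not part of this patch:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

// Hypothetical case class used only for this illustration.
case class Person(id: Int)

object TupleOptionEncoderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("tuple-option-encoder-sketch")
      .getOrCreate()

    // Combine a product encoder with an Option-of-product encoder,
    // mirroring the Encoders.tuple call in the test above.
    implicit val enc: Encoder[(Person, Option[Person])] =
      Encoders.tuple(Encoders.product[Person], Encoders.product[Option[Person]])

    val data = Seq(
      (Person(1), Some(Person(1))),
      (Person(2), None))

    // The Option side is encoded as a nullable struct column; None rows
    // come back as None after the round trip rather than being corrupted.
    val ds = spark.createDataFrame(data).as[(Person, Option[Person])]
    ds.show(truncate = false)

    spark.stop()
  }
}

This is only a sketch of the API shape exercised by the test; the actual behavioral guarantee (that the None component survives encoding) is what the checkDataset assertion in the patch verifies.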
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index abec582d43a30..7c285759fcd9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -432,7 +432,6 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { }, errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", parameters = Map( - "clause" -> "aggregate", "num" -> "2", "generators" -> ("\"explode(array(min(c2), max(c2)))\", " + "\"posexplode(array(min(c2), max(c2)))\""))) @@ -536,6 +535,39 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(1, 1) :: Row(1, 2) :: Row(2, 2) :: Row(2, 3) :: Row(3, null) :: Nil) } + + test("SPARK-45171: Handle evaluated nondeterministic expression") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val df = sql("select explode(array(rand(0)))") + checkAnswer(df, Row(0.7604953758285915d)) + } + } + + test("SPARK-47241: two generator functions in SELECT") { + def testTwoGenerators(needImplicitCast: Boolean): Unit = { + val df = sql( + s""" + |SELECT + |explode(array('a', 'b')) as c1, + |explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as c2 + |""".stripMargin) + checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b", 1L))) + } + testTwoGenerators(needImplicitCast = true) + testTwoGenerators(needImplicitCast = false) + } + + test("SPARK-47241: generator function after wildcard in SELECT") { + val df = sql( + s""" + |SELECT *, explode(array('a', 'b')) as c1 + |FROM + |( + | SELECT id FROM range(1) GROUP BY 1 + |) + |""".stripMargin) + checkAnswer(df, Seq(Row(0, "a"), Row(0, "b"))) + } } case class EmptyGenerator() extends Generator with LeafLike[Expression] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 14f1fb27906a1..4d256154c8574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ListBuffer import org.mockito.Mockito._ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.internal.config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} @@ -32,11 +33,11 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} import 
org.apache.spark.sql.types.StructType import org.apache.spark.tags.SlowSQLTest @@ -1729,4 +1730,60 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan checkAnswer(joined, expected) } + + test("SPARK-45882: BroadcastHashJoinExec propagate partitioning should respect " + + "CoalescedHashPartitioning") { + val cached = spark.sql( + """ + |select /*+ broadcast(testData) */ key, value, a + |from testData join ( + | select a from testData2 group by a + |)tmp on key = a + |""".stripMargin).cache() + try { + val df = cached.groupBy("key").count() + val expected = Seq(Row(1, 1), Row(2, 1), Row(3, 1)) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.size == 1, df.queryExecution) + checkAnswer(df, expected) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.isEmpty, df.queryExecution) + } finally { + cached.unpersist() + } + } +} + +class ThreadLeakInSortMergeJoinSuite + extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + setupTestData() + override protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() + new TestSparkSession( + sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20)) + } + + test("SPARK-47146: thread leak when doing SortMergeJoin (with spill)") { + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + + assertSpilled(sparkContext, "inner join") { + sql("SELECT * FROM testData JOIN testData2 ON key = a").collect() + } + + val readAheadThread = Thread.getAllStackTraces.keySet().asScala + .find { + _.getName.startsWith("read-ahead") + } + assert(readAheadThread.isEmpty) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index a76e102fe913f..0d30d8e95ee7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -1144,6 +1144,38 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-48863: parse object as an array with partial results enabled") { + val schema = StructType(StructField("a", StringType) :: StructField("c", IntegerType) :: Nil) + + // Value can be parsed correctly and should return the same result with or without the flag. + Seq(false, true).foreach { enabled => + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> s"${enabled}") { + checkAnswer( + Seq("""{"a": "b", "c": 1}""").toDF("c0") + .select(from_json($"c0", ArrayType(schema))), + Row(Seq(Row("b", 1))) + ) + } + } + + // Value does not match the schema. 
+ val df = Seq("""{"a": "b", "c": "1"}""").toDF("c0") + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + checkAnswer( + df.select(from_json($"c0", ArrayType(schema))), + Row(Seq(Row("b", null))) + ) + } + + withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "false") { + checkAnswer( + df.select(from_json($"c0", ArrayType(schema))), + Row(null) + ) + } + } + test("SPARK-33270: infers schema for JSON field with spaces and pass them to from_json") { val in = Seq("""{"a b": 1}""").toDS() val out = in.select(from_json($"value", schema_of_json("""{"a b": 100}""")) as "parsed") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 0adb89c3a9eaf..ba04e3b691a1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -262,6 +262,17 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-44973: conv must allocate enough space for all digits plus negative sign") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + val df = Seq( + ((BigInt(Long.MaxValue) + 1).toString(16)), + (BigInt(Long.MinValue).toString(16)) + ).toDF("num") + checkAnswer(df.select(conv($"num", 16, -2)), + Seq(Row(BigInt(Long.MinValue).toString(2)), Row(BigInt(Long.MinValue).toString(2)))) + } + } + test("floor") { testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong) // testOneToOneMathFunction does not validate the resulting data type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index a72c9a600adea..773b7041dee5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -502,4 +502,61 @@ class ParametersSuite extends QueryTest with SharedSparkSession { start = 24, stop = 36)) } + + test("SPARK-49017: bind named parameters with IDENTIFIER clause") { + withTable("testtab") { + // Create table + spark.sql("create table testtab (id int, name string) using parquet") + + // Insert into table using single param + spark.sql("insert into identifier(:tab) values(1, 'test1')", Map("tab" -> "testtab")) + + // Select from table using param + checkAnswer(spark.sql("select * from identifier(:tab)", Map("tab" -> "testtab")), + Seq(Row(1, "test1"))) + + // Insert into table using multiple params + spark.sql("insert into identifier(:tab) values(2, :name)", + Map("tab" -> "testtab", "name" -> "test2")) + + // Select from table using param + checkAnswer(sql("select * from testtab"), Seq(Row(1, "test1"), Row(2, "test2"))) + + // Insert into table using multiple params and idents + sql("insert into testtab values(2, 'test3')") + + // Select from table using param + checkAnswer(spark.sql("select identifier(:col) from identifier(:tab) where :name == name", + Map("tab" -> "testtab", "name" -> "test2", "col" -> "id")), Seq(Row(2))) + } + } + + test("SPARK-49017: bind positional parameters with IDENTIFIER clause") { + withTable("testtab") { + // Create table + spark.sql("create table testtab (id int, name string) using parquet") + + // Insert into table using single param + spark.sql("insert into identifier(?) 
values(1, 'test1')", + Array("testtab")) + + // Select from table using param + checkAnswer(spark.sql("select * from identifier(?)", Array("testtab")), + Seq(Row(1, "test1"))) + + // Insert into table using multiple params + spark.sql("insert into identifier(?) values(2, ?)", + Array("testtab", "test2")) + + // Select from table using param + checkAnswer(sql("select * from testtab"), Seq(Row(1, "test1"), Row(2, "test2"))) + + // Insert into table using multiple params and idents + sql("insert into testtab values(2, 'test3')") + + // Select from table using param + checkAnswer(spark.sql("select identifier(?) from identifier(?) where ? == name", + Array("id", "testtab", "test2")), Seq(Row(2))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index b5b3492269415..c26757c9cff70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -256,9 +256,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = { val queryString = resourceToString(s"$tpcdsGroup/$query.sql", classLoader = Thread.currentThread().getContextClassLoader) - // Disable char/varchar read-side handling for better performance. - withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { + withSQLConf( + // Disable char/varchar read-side handling for better performance. + SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", + SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val qe = sql(queryString).queryExecution val plan = qe.executedPlan val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala index 29b2796d25aa4..79b2f517b060b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala @@ -76,57 +76,59 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { } test("INSERT into partitioned tables") { - sql("create table t(c1 int, c2 int, c3 int, c4 int) using parquet partitioned by (c3, c4)") - - // INSERT without static partitions - checkError( - exception = intercept[AnalysisException] { - sql("insert into t values (1, 2, 3)") - }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", - parameters = Map( - "tableName" -> "`spark_catalog`.`default`.`t`", - "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", - "dataColumns" -> "`col1`, `col2`, `col3`")) - - // INSERT without static partitions but with column list - sql("truncate table t") - sql("insert into t (c2, c1, c4) values (1, 2, 3)") - checkAnswer(spark.table("t"), Row(2, 1, null, 3)) - - // INSERT with static partitions - sql("truncate table t") - checkError( - exception = intercept[AnalysisException] { - sql("insert into t partition(c3=3, c4=4) values (1)") - }, - errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", - parameters = Map( - "tableName" -> "`spark_catalog`.`default`.`t`", - "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", - "dataColumns" -> "`col1`", - "staticPartCols" -> "`c3`, `c4`")) - - // 
INSERT with static partitions and with column list - sql("truncate table t") - sql("insert into t partition(c3=3, c4=4) (c2) values (1)") - checkAnswer(spark.table("t"), Row(null, 1, 3, 4)) - - // INSERT with partial static partitions - sql("truncate table t") - checkError( - exception = intercept[AnalysisException] { - sql("insert into t partition(c3=3, c4) values (1, 2)") - }, - errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", - parameters = Map( - "tableName" -> "`spark_catalog`.`default`.`t`", - "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", - "dataColumns" -> "`col1`, `col2`", - "staticPartCols" -> "`c3`")) - - // INSERT with partial static partitions and with column list is not allowed - intercept[AnalysisException](sql("insert into t partition(c3=3, c4) (c1) values (1, 4)")) + withTable("t") { + sql("create table t(c1 int, c2 int, c3 int, c4 int) using parquet partitioned by (c3, c4)") + + // INSERT without static partitions + checkError( + exception = intercept[AnalysisException] { + sql("insert into t values (1, 2, 3)") + }, + errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`t`", + "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", + "dataColumns" -> "`col1`, `col2`, `col3`")) + + // INSERT without static partitions but with column list + sql("truncate table t") + sql("insert into t (c2, c1, c4) values (1, 2, 3)") + checkAnswer(spark.table("t"), Row(2, 1, null, 3)) + + // INSERT with static partitions + sql("truncate table t") + checkError( + exception = intercept[AnalysisException] { + sql("insert into t partition(c3=3, c4=4) values (1)") + }, + errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`t`", + "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", + "dataColumns" -> "`col1`", + "staticPartCols" -> "`c3`, `c4`")) + + // INSERT with static partitions and with column list + sql("truncate table t") + sql("insert into t partition(c3=3, c4=4) (c2) values (1)") + checkAnswer(spark.table("t"), Row(null, 1, 3, 4)) + + // INSERT with partial static partitions + sql("truncate table t") + checkError( + exception = intercept[AnalysisException] { + sql("insert into t partition(c3=3, c4) values (1, 2)") + }, + errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`t`", + "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", + "dataColumns" -> "`col1`, `col2`", + "staticPartCols" -> "`c3`")) + + // INSERT with partial static partitions and with column list is not allowed + intercept[AnalysisException](sql("insert into t partition(c3=3, c4) (c1) values (1, 4)")) + } } test("SPARK-43085: Column DEFAULT assignment for target tables with multi-part names") { @@ -213,4 +215,25 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-49054: Create table with current_user() default") { + val tableName = "test_current_user" + val user = spark.sparkContext.sparkUser + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i int, s string default current_user()) USING parquet") + sql(s"INSERT INTO $tableName (i) VALUES ((0))") + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user))) + } + } + + test("SPARK-49054: Alter table with current_user() default") { + val tableName = "test_current_user" + val user = spark.sparkContext.sparkUser + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i int, s string) USING parquet") + sql(s"ALTER TABLE 
$tableName ALTER COLUMN s SET DEFAULT current_user()") + sql(s"INSERT INTO $tableName (i) VALUES ((0))") + checkAnswer(sql(s"SELECT * FROM $tableName"), Seq(Row(0, user))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cfeccbdf648c2..793a0da6a8622 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex} +import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow, Hex} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} @@ -1430,6 +1430,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("SPARK-49200: Fix null type non-codegen ordering exception") { + withSQLConf( + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString, + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.EliminateSorts") { + checkAnswer( + sql("SELECT * FROM range(3) ORDER BY array(null)"), + Seq(Row(0), Row(1), Row(2))) + } + } + test("SPARK-8837: use keyword in column name") { withTempView("t") { val df = Seq(1 -> "a").toDF("count", "sort") @@ -4700,6 +4711,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val df6 = df3.join(df2, col("df3.zaak_id") === col("df2.customer_id"), "outer") df5.crossJoin(df6) } + + test("SPARK-49743: OptimizeCsvJsonExpr does not change schema when pruning struct") { + val df = sql(""" + | SELECT + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').a, + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').A + | FROM + | range(3) as t + |""".stripMargin) + val expectedAnswer = Seq( + Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2))) + checkAnswer(df, expectedAnswer) + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 21518085ca4c5..8b4ac474f8753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint} -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import 
org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt } } } + + test("SPARK-46170: Support inject adaptive query post planner strategy rules in " + + "SparkSessionExtensions") { + val extensions = create { extensions => + extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule) + } + withSession(extensions) { session => + assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules + .contains(MyQueryPostPlannerStrategyRule)) + import session.implicits._ + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { + val input = Seq(10, 20, 10).toDF("c1") + val df = input.groupBy("c1").count() + df.collect() + assert(df.rdd.partitions.length == 1) + assert(collectFirst(df.queryExecution.executedPlan) { + case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true + }.isDefined) + assert(collectFirst(df.queryExecution.executedPlan) { + case _: SortExec => true + }.isDefined) + } + } + } } case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { @@ -1190,3 +1217,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] { } } } + +object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + plan.transformUp { + case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) => + ShuffleExchangeExec(SinglePartition, h) + case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) => + SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 04e47ac4a1132..4468edc483735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SQLTestUtils @@ -269,7 +270,8 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils def 
getTableFromCatalogCache(tableName: String): LogicalPlan = { val catalog = spark.sessionState.catalog - val qualifiedTableName = QualifiedTableName(catalog.getCurrentDatabase, tableName) + val qualifiedTableName = QualifiedTableName( + CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, tableName) catalog.getCachedTable(qualifiedTableName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 88c9e15570e30..fa1a64460fcb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.{SPARK_DOC_ROOT, SparkRuntimeException} +import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.execution.FormattedMode import org.apache.spark.sql.functions._ @@ -1173,6 +1173,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(try_to_number(col("a"), lit("$99.99"))), Seq(Row(78.12))) } + test("SPARK-47646: try_to_number should return NULL for malformed input") { + val df = spark.createDataset(spark.sparkContext.parallelize(Seq("11"))) + checkAnswer(df.select(try_to_number($"value", lit("$99.99"))), Seq(Row(null))) + } + test("SPARK-44905: stateful lastRegex causes NullPointerException on eval for regexp_replace") { val df = sql("select regexp_replace('', '[a\\\\d]{0, 2}', 'x')") intercept[SparkRuntimeException](df.queryExecution.optimizedPlan) @@ -1186,4 +1191,13 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { ) ) } + + test("SPARK-48806: url_decode exception") { + val e = intercept[SparkIllegalArgumentException] { + sql("select url_decode('https%3A%2F%2spark.apache.org')").collect() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException] && + e.getCause.getMessage + .startsWith("URLDecoder: Illegal hex characters in escape (%) pattern - ")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index d235d2a15fea3..260c992f1aed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2212,6 +2212,24 @@ class SubquerySuite extends QueryTest } } + test("SPARK-49819: Do not collapse projects with exist subqueries") { + withTempView("v") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") + checkAnswer( + sql(""" + |SELECT m, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = m) THEN 1 ELSE 0 END + |FROM (SELECT MIN(c2) AS m FROM v) + |""".stripMargin), + Row(1, 1) :: Nil) + checkAnswer( + sql(""" + |SELECT c, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = c) THEN 1 ELSE 0 END + |FROM (SELECT c1 AS c FROM v GROUP BY c1) + |""".stripMargin), + Row(0, 1) :: Row(1, 1) :: Nil) + } + } + test("SPARK-37199: deterministic in QueryPlan considers subquery") { val deterministicQueryPlan = sql("select (select 1 as b) as b") .queryExecution.executedPlan @@ -2712,4 +2730,74 @@ class SubquerySuite extends QueryTest expected) } } + + test("SPARK-45584: subquery execution should not fail with ORDER BY and LIMIT") { + withTable("t1") { + sql( + """ + |CREATE TABLE t1 USING PARQUET + |AS SELECT * FROM VALUES + |(1, "a"), + 
|(2, "a"), + |(3, "a") t(id, value) + |""".stripMargin) + val df = sql( + """ + |WITH t2 AS ( + | SELECT * FROM t1 ORDER BY id + |) + |SELECT *, (SELECT COUNT(*) FROM t2) FROM t2 LIMIT 10 + |""".stripMargin) + // This should not fail with IllegalArgumentException. + checkAnswer( + df, + Row(1, "a", 3) :: Row(2, "a", 3) :: Row(3, "a", 3) :: Nil) + } + } + + test("SPARK-45580: Handle case where a nested subquery becomes an existence join") { + withTempView("t1", "t2", "t3") { + Seq((1), (2), (3), (7)).toDF("a").persist().createOrReplaceTempView("t1") + Seq((1), (2), (3)).toDF("c1").persist().createOrReplaceTempView("t2") + Seq((3), (9)).toDF("col1").persist().createOrReplaceTempView("t3") + + val query1 = + """ + |SELECT * + |FROM t1 + |WHERE EXISTS ( + | SELECT c1 + | FROM t2 + | WHERE a = c1 + | OR a IN (SELECT col1 FROM t3) + |)""".stripMargin + val df1 = sql(query1) + checkAnswer(df1, Row(1) :: Row(2) :: Row(3) :: Nil) + + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE a IN ( + | SELECT c1 + | FROM t2 + | where a IN (SELECT col1 FROM t3) + |)""".stripMargin + val df2 = sql(query2) + checkAnswer(df2, Row(3)) + + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE NOT EXISTS ( + | SELECT c1 + | FROM t2 + | WHERE a = c1 + | OR a IN (SELECT col1 FROM t3) + |)""".stripMargin + val df3 = sql(query3) + checkAnswer(df3, Row(7)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e54bda1acef59..9f8e979e3fba7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -1067,4 +1067,22 @@ class UDFSuite extends QueryTest with SharedSparkSession { .lookupFunctionInfo(FunctionIdentifier("dummyUDF")) assert(expressionInfo.getClassName.contains("org.apache.spark.sql.UDFRegistration$$Lambda")) } + + test("SPARK-47927: Correctly pass null values derived from join to UDF") { + val f = udf[Tuple1[Option[Int]], Tuple1[Option[Int]]](identity) + val ds1 = Seq(1).toDS() + val ds2 = Seq[Int]().toDS() + + checkAnswer( + ds1.join(ds2, ds1("value") === ds2("value"), "left_outer") + .select(f(struct(ds2("value").as("_1")))), + Row(Row(null))) + } + + test("SPARK-47927: ScalaUDF null handling") { + val f = udf[Int, Int](_ + 1) + val df = Seq(Some(1), None).toDF("c") + .select(f($"c").as("f"), f($"f")) + checkAnswer(df, Seq(Row(2, 3), Row(null, null))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 835566238c9c1..9dd20c906535e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -55,8 +55,7 @@ class DataSourceV2DataFrameSessionCatalogSuite "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - val format = spark.sessionState.conf.defaultDataSourceName - sql(s"CREATE TABLE same_name(id LONG) USING $format") + sql(s"CREATE TABLE same_name(id LONG) USING $v2Format") spark.range(10).createTempView("same_name") spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) @@ -88,6 +87,15 @@ class DataSourceV2DataFrameSessionCatalogSuite 
assert(tableInfo.properties().get("provider") === v2Format) } } + + test("SPARK-49246: saveAsTable with v1 format") { + withTable("t") { + sql("CREATE TABLE t(c INT) USING csv") + val df = spark.range(10).toDF() + df.write.mode(SaveMode.Overwrite).format("csv").saveAsTable("t") + verifyTable("t", df) + } + } } class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { @@ -110,7 +118,14 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable Option(tables.get(ident)) match { case Some(table) => val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) - val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE") + val provider = Option(properties.get("provider")) + + val schema = CatalogV2Util.applySchemaChanges( + table.schema, + changes, + provider, + "ALTER TABLE" + ) // fail if the last column in the schema was dropped if (schema.fields.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala index 95624f3f61c5c..7463eb34d17ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala @@ -71,4 +71,12 @@ class DataSourceV2SQLSessionCatalogSuite sql(s"CREATE EXTERNAL TABLE t (i INT) USING $v2Format TBLPROPERTIES($prop)") } } + + test("SPARK-49152: partition columns should be put at the end") { + withTable("t") { + sql("CREATE TABLE t (c1 INT, c2 INT) USING json PARTITIONED BY (c1)") + // partition columns should be put at the end. + assert(getTableMetadata("default.t").columns().map(_.name()) === Seq("c2", "c1")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 06f5600e0d199..9df4b0932f25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector +import java.{util => jutil} import java.sql.Timestamp import java.time.{Duration, LocalDate, Period} import java.util.Locale @@ -26,8 +27,9 @@ import scala.concurrent.duration.MICROSECONDS import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase @@ -35,11 +37,12 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership -import 
org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -431,6 +434,25 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-49152: CreateTable should store location as qualified") { + val tbl = "testcat.table_name" + + def testWithLocation(location: String, qualified: String): Unit = { + withTable(tbl) { + sql(s"CREATE TABLE $tbl USING foo LOCATION '$location'") + val loc = catalog("testcat").asTableCatalog + .loadTable(Identifier.of(Array.empty, "table_name")) + .properties().get(TableCatalog.PROP_LOCATION) + assert(loc === qualified) + } + } + + testWithLocation("/absolute/path", "file:/absolute/path") + testWithLocation("s3://host/full/path", "s3://host/full/path") + testWithLocation("relative/path", "relative/path") + testWithLocation("/path/special+ char", "file:/path/special+ char") + } + test("SPARK-37545: CreateTableAsSelect should store location as qualified") { val basicIdentifier = "testcat.table_name" val atomicIdentifier = "testcat_atomic.table_name" @@ -440,7 +462,7 @@ class DataSourceV2SQLSuiteV1Filter "AS SELECT id FROM source") val location = spark.sql(s"DESCRIBE EXTENDED $identifier") .filter("col_name = 'Location'") - .select("data_type").head.getString(0) + .select("data_type").head().getString(0) assert(location === "file:/tmp/foo") } } @@ -457,7 +479,7 @@ class DataSourceV2SQLSuiteV1Filter "AS SELECT id FROM source") val location = spark.sql(s"DESCRIBE EXTENDED $identifier") .filter("col_name = 'Location'") - .select("data_type").head.getString(0) + .select("data_type").head().getString(0) assert(location === "file:/tmp/foo") } } @@ -1356,8 +1378,7 @@ class DataSourceV2SQLSuiteV1Filter val identifier = Identifier.of(Array(), "reservedTest") val location = tableCatalog.loadTable(identifier).properties() .get(TableCatalog.PROP_LOCATION) - assert(location.startsWith("file:") && location.endsWith("foo"), - "path as a table property should not have side effects") + assert(location == "foo", "path as a table property should not have side effects") assert(tableCatalog.loadTable(identifier).properties().get("path") == "bar", "path as a table property should not have side effects") assert(tableCatalog.loadTable(identifier).properties().get("Path") == "noop", @@ -1448,7 +1469,7 @@ class DataSourceV2SQLSuiteV1Filter } checkError(exception, errorClass = "SCHEMA_NOT_FOUND", - parameters = Map("schemaName" -> "`ns1`.`ns2`")) + parameters = Map("schemaName" -> "`testcat`.`ns1`.`ns2`")) } test("SPARK-31100: Use: v2 catalog that does not implement SupportsNameSpaces is used " + @@ -1736,6 +1757,16 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-48709: varchar resolution mismatch for DataSourceV2 CTAS") { + withSQLConf( + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString) { + withTable("testcat.ns.t1", "testcat.ns.t2") { + sql("CREATE TABLE testcat.ns.t1 (d1 string, d2 varchar(200)) USING parquet") + sql("CREATE TABLE testcat.ns.t2 USING foo as 
select * from testcat.ns.t1") + } + } + } + test("ShowCurrentNamespace: basic tests") { def testShowCurrentNamespace(expectedCatalogName: String, expectedNamespace: String): Unit = { val schema = new StructType() @@ -2064,8 +2095,11 @@ class DataSourceV2SQLSuiteV1Filter exception = e, errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", sqlState = "0A000", - parameters = Map("tableName" -> "`spark_catalog`.`default`.`tbl`", - "operation" -> "REPLACE TABLE")) + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`tbl`", + "operation" -> "REPLACE TABLE" + ) + ) } test("DeleteFrom: - delete with invalid predicate") { @@ -3014,6 +3048,17 @@ class DataSourceV2SQLSuiteV1Filter sqlState = None, parameters = Map("relationId" -> "`x`")) + checkError( + exception = intercept[AnalysisException] { + sql("SELECT * FROM non_exist VERSION AS OF 1") + }, + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`non_exist`"), + context = ExpectedContext( + fragment = "non_exist", + start = 14, + stop = 22)) + val subquery1 = "SELECT 1 FROM non_exist" checkError( exception = intercept[AnalysisException] { @@ -3131,7 +3176,7 @@ class DataSourceV2SQLSuiteV1Filter val properties = table.properties assert(properties.get(TableCatalog.PROP_PROVIDER) == "parquet") assert(properties.get(TableCatalog.PROP_COMMENT) == "This is a comment") - assert(properties.get(TableCatalog.PROP_LOCATION) == "file:///tmp") + assert(properties.get(TableCatalog.PROP_LOCATION) == "file:/tmp") assert(properties.containsKey(TableCatalog.PROP_OWNER)) assert(properties.get(TableCatalog.PROP_EXTERNAL) == "true") assert(properties.get(s"${TableCatalog.OPTION_PREFIX}from") == "0") @@ -3295,6 +3340,196 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-48286: Add new column with default value which is not foldable") { + val foldableExpressions = Seq("1", "2 + 1") + withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> v2Source) { + withTable("tab") { + spark.sql(s"CREATE TABLE tab (col1 INT DEFAULT 100) USING $v2Source") + val exception = intercept[AnalysisException] { + // Rand function is not foldable + spark.sql(s"ALTER TABLE tab ADD COLUMN col2 DOUBLE DEFAULT rand()") + } + assert(exception.errorClass.get == "INVALID_DEFAULT_VALUE.NOT_CONSTANT") + assert(exception.messageParameters("colName") == "`col2`") + assert(exception.messageParameters("defaultValue") == "rand()") + assert(exception.messageParameters("statement") == "ALTER TABLE") + } + foldableExpressions.foreach(expr => { + withTable("tab") { + spark.sql(s"CREATE TABLE tab (col1 INT DEFAULT 100) USING $v2Source") + spark.sql(s"ALTER TABLE tab ADD COLUMN col2 DOUBLE DEFAULT $expr") + } + }) + } + } + + test("SPARK-49099: Switch current schema with custom spark_catalog") { + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. + spark.sessionState.catalogManager.reset() + withSQLConf(V2_SESSION_CATALOG_IMPLEMENTATION.key -> classOf[InMemoryCatalog].getName) { + sql("CREATE DATABASE test_db") + sql("USE test_db") + } + } + + test("SPARK-49183: custom spark_catalog generates location for managed tables") { + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. 
+ spark.sessionState.catalogManager.reset() + withSQLConf(V2_SESSION_CATALOG_IMPLEMENTATION.key -> classOf[SimpleDelegatingCatalog].getName) { + withTable("t") { + sql(s"CREATE TABLE t (i INT) USING $v2Format") + val table = catalog(SESSION_CATALOG_NAME).asTableCatalog + .loadTable(Identifier.of(Array("default"), "t")) + assert(!table.properties().containsKey(TableCatalog.PROP_EXTERNAL)) + } + } + } + + test("SPARK-49211: V2 Catalog can support built-in data sources") { + def checkParquet(tableName: String, path: String): Unit = { + withTable(tableName) { + sql("CREATE TABLE " + tableName + + " (name STRING) USING PARQUET LOCATION '" + path + "'") + sql("INSERT INTO " + tableName + " VALUES('Bob')") + val df = sql("SELECT * FROM " + tableName) + assert(df.queryExecution.analyzed.exists { + case LogicalRelation(_: HadoopFsRelation, _, _, _) => true + case _ => false + }) + checkAnswer(df, Row("Bob")) + } + } + + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + spark.sessionState.catalogManager.reset() + withSQLConf( + V2_SESSION_CATALOG_IMPLEMENTATION.key -> + classOf[V2CatalogSupportBuiltinDataSource].getName) { + withTempPath { path => + checkParquet(table1.toString, path.getAbsolutePath) + } + } + val table2 = QualifiedTableName("testcat3", "default", "t") + withSQLConf( + "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { + withTempPath { path => + checkParquet(table2.toString, path.getAbsolutePath) + } + } + } + + test("SPARK-49211: V2 Catalog support CTAS") { + def checkCTAS(tableName: String, path: String): Unit = { + sql("CREATE TABLE " + tableName + " USING PARQUET LOCATION '" + path + + "' AS SELECT 1, 2, 3") + checkAnswer(sql("SELECT * FROM " + tableName), Row(1, 2, 3)) + } + + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. 
+ spark.sessionState.catalogManager.reset() + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + withSQLConf( + V2_SESSION_CATALOG_IMPLEMENTATION.key -> + classOf[V2CatalogSupportBuiltinDataSource].getName) { + withTempPath { path => + checkCTAS(table1.toString, path.getAbsolutePath) + } + } + + val table2 = QualifiedTableName("testcat3", "default", "t") + withSQLConf( + "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { + withTempPath { path => + checkCTAS(table2.toString, path.getAbsolutePath) + } + } + } + + test("SPARK-49246: read-only catalog") { + def assertPrivilegeError(f: => Unit, privilege: String): Unit = { + val e = intercept[RuntimeException](f) + assert(e.getMessage.contains(privilege)) + } + + def checkWriteOperations(catalog: String): Unit = { + withSQLConf(s"spark.sql.catalog.$catalog" -> classOf[ReadOnlyCatalog].getName) { + val input = sql("SELECT 1") + val tbl = s"$catalog.default.t" + withTable(tbl) { + sql(s"CREATE TABLE $tbl (i INT)") + val df = sql(s"SELECT * FROM $tbl") + assert(df.collect().isEmpty) + assert(df.schema == new StructType().add("i", "int")) + + assertPrivilegeError(sql(s"INSERT INTO $tbl SELECT 1"), "INSERT") + assertPrivilegeError( + sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"INSERT OVERWRITE $tbl SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"DELETE FROM $tbl WHERE i = 0"), "DELETE") + assertPrivilegeError(sql(s"UPDATE $tbl SET i = 0"), "UPDATE") + assertPrivilegeError( + sql(s""" + |MERGE INTO $tbl USING (SELECT 1 i) AS source + |ON source.i = $tbl.i + |WHEN MATCHED THEN UPDATE SET * + |WHEN NOT MATCHED THEN INSERT * + |WHEN NOT MATCHED BY SOURCE THEN DELETE + |""".stripMargin), + "DELETE,INSERT,UPDATE" + ) + + assertPrivilegeError(input.write.insertInto(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").insertInto(tbl), "DELETE,INSERT") + assertPrivilegeError(input.write.mode("append").saveAsTable(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").saveAsTable(tbl), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).append(), "INSERT") + assertPrivilegeError(input.writeTo(tbl).overwrite(df.col("i") === 1), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).overwritePartitions(), "DELETE,INSERT") + } + + // Test CTAS + withTable(tbl) { + assertPrivilegeError(sql(s"CREATE TABLE $tbl AS SELECT 1 i"), "INSERT") + } + withTable(tbl) { + assertPrivilegeError(sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i"), "INSERT") + } + withTable(tbl) { + assertPrivilegeError(input.write.saveAsTable(tbl), "INSERT") + } + withTable(tbl) { + assertPrivilegeError(input.writeTo(tbl).create(), "INSERT") + } + withTable(tbl) { + assertPrivilegeError(input.writeTo(tbl).createOrReplace(), "INSERT") + } + } + } + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. 
+ spark.sessionState.catalogManager.reset() + checkWriteOperations(SESSION_CATALOG_NAME) + checkWriteOperations("read_only_cat") + } + + test("StagingTableCatalog without atomic support") { + withSQLConf("spark.sql.catalog.fakeStagedCat" -> classOf[FakeStagedTableCatalog].getName) { + withTable("fakeStagedCat.t") { + sql("CREATE TABLE fakeStagedCat.t AS SELECT 1 col") + checkAnswer(spark.table("fakeStagedCat.t"), Row(1)) + sql("REPLACE TABLE fakeStagedCat.t AS SELECT 2 col") + checkAnswer(spark.table("fakeStagedCat.t"), Row(2)) + sql("CREATE OR REPLACE TABLE fakeStagedCat.t AS SELECT 1 c1, 2 c2") + checkAnswer(spark.table("fakeStagedCat.t"), Row(1, 2)) + } + } + } + private def testNotSupportedV2Command( sqlCommand: String, sqlParams: String, @@ -3323,3 +3558,126 @@ class FakeV2Provider extends SimpleTableProvider { class ReserveSchemaNullabilityCatalog extends InMemoryCatalog { override def useNullableQuerySchema(): Boolean = false } + +class SimpleDelegatingCatalog extends DelegatingCatalogExtension { + override def createTable( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): Table = { + val newProps = new jutil.HashMap[String, String] + newProps.putAll(properties) + newProps.put(TableCatalog.PROP_LOCATION, "/tmp/test_path") + newProps.put(TableCatalog.PROP_IS_MANAGED_LOCATION, "true") + super.createTable(ident, columns, partitions, newProps) + } +} + + +class V2CatalogSupportBuiltinDataSource extends InMemoryCatalog { + override def createTable( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): Table = { + super.createTable(ident, columns, partitions, properties) + null + } + + override def loadTable(ident: Identifier): Table = { + val superTable = super.loadTable(ident) + val tableIdent = { + TableIdentifier(ident.name(), Some(ident.namespace().head), Some(name)) + } + val uri = CatalogUtils.stringToURI(superTable.properties().get(TableCatalog.PROP_LOCATION)) + val sparkTable = CatalogTable( + tableIdent, + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy( + locationUri = Some(uri), + properties = superTable.properties().asScala.toMap + ), + schema = superTable.schema(), + provider = Some(superTable.properties().get(TableCatalog.PROP_PROVIDER)), + tracksPartitionsInCatalog = false + ) + V1Table(sparkTable) + } +} + +class ReadOnlyCatalog extends InMemoryCatalog { + override def createTable( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): Table = { + super.createTable(ident, columns, partitions, properties) + null + } + + override def loadTable( + ident: Identifier, + writePrivileges: jutil.Set[TableWritePrivilege]): Table = { + throw new RuntimeException("cannot write with " + + writePrivileges.asScala.toSeq.map(_.toString).sorted.mkString(",")) + } +} + +class FakeStagedTableCatalog extends InMemoryCatalog with StagingTableCatalog { + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + throw new RuntimeException("shouldn't be called") + } + + override def stageCreate( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + super.createTable(ident, columns, partitions, properties) + null + } + + override def stageReplace( + ident: Identifier, 
+ schema: StructType, + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + throw new RuntimeException("shouldn't be called") + } + + override def stageReplace( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + super.dropTable(ident) + super.createTable(ident, columns, partitions, properties) + null + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + throw new RuntimeException("shouldn't be called") + } + + override def stageCreateOrReplace( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): StagedTable = { + try { + super.dropTable(ident) + } catch { + case _: Throwable => + } + super.createTable(ident, columns, partitions, properties) + null + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 52d0151ee4623..d269290e6162d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -626,6 +626,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } } + + test("SPARK-47463: Pushed down v2 filter with if expression") { + withTempView("t1") { + spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + .createTempView("t1") + val df = sql("SELECT * FROM t1 WHERE if(i = 1, i, 0) > 0") + val result = df.collect() + assert(result.length == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8461f528277c3..71e030f535e9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -330,6 +330,28 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { .add("price", FloatType) .add("time", TimestampType) + test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") { + val cols = new StructType() + .add("id", LongType) + .add("name", StringType) + val buckets = Array(bucket(8, "id")) + + withTable("t1", "t2", "t3") { + Seq("t1", "t2", "t3").foreach { t => + createTable(t, cols, buckets) + sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')") + } + val df = sql( + """ + |SELECT t1.id, t2.id, t3.name FROM testcat.ns.t1 + |JOIN testcat.ns.t2 ON t1.id = t2.id + |JOIN testcat.ns.t3 ON t1.id = t3.id + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) + assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) + } + } + test("partitioned join: join with two partition keys and matching & sorted partitions") { val items_partitions = Array(bucket(8, "id"), days("arrive_time")) createTable(items, items_schema, items_partitions) @@ -1095,4 +1117,46 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-45652: SPJ should handle empty partition after dynamic filtering") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => { + withSQLConf( + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { + // The dynamic filtering effectively filtered out all the partitions + val df = sql(s"SELECT p.price from testcat.ns.$items i, testcat.ns.$purchases p " + + "WHERE i.id = p.item_id AND i.price > 50.0") + checkAnswer(df, Seq.empty) + } + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index e7555c23fa4fc..5668e5981910c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -32,6 +32,38 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { import testImplicits._ + test("SPARK-45974: merge into non filter attributes table") { + val tableName: String = "cat.ns1.non_partitioned_table" + withTable(tableName) { + withTempView("source") { + val sourceRows = Seq( + (1, 100, "hr"), + (2, 200, "finance"), + (3, 300, "hr")) + sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source") + + sql(s"CREATE TABLE $tableName (pk INT NOT NULL, salary INT, dep STRING)".stripMargin) + + val df = sql( + s"""MERGE INTO $tableName t + |USING (select * from source) s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Seq( + Row(1, 100, "hr"), // insert + Row(2, 200, "finance"), // insert + Row(3, 300, "hr"))) // insert + } + } + } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index 46586c622db79..bd13123d587f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -22,8 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.catalog.CatalogTableType -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -53,14 +52,10 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating if (tables.containsKey(ident)) { tables.get(ident) } else { - // Table was created through the built-in catalog - super.loadTable(ident) match { - case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW => v1Table - case t => - val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties()) - addTable(ident, table) - table - } + // The table was created through the built-in catalog via a v1 command. This is OK because + // `loadTable` is always invoked in that path, and we set `tableCreated` to pass validation. + tableCreated.set(true) + super.loadTable(ident) } } @@ -78,22 +73,27 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating schema: StructType, partitions: Array[Transform], properties: java.util.Map[String, String]): Table = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper + val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY - val propsWithLocation = if (properties.containsKey(key)) { + val newProps = new java.util.HashMap[String, String]() + newProps.putAll(properties) + if (properties.containsKey(TableCatalog.PROP_LOCATION)) { + newProps.put(TableCatalog.PROP_EXTERNAL, "true") + } + + val propsWithLocation = if (newProps.containsKey(key)) { // Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified.
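The property rewrite above encodes the usual Spark rule that an explicit LOCATION makes a table external; a minimal spark-shell style sketch of what that looks like at the SQL level, with an illustrative table name and path.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    spark.sql("CREATE TABLE demo_tbl (i INT) USING parquet LOCATION '/tmp/demo_tbl'")
    // DESCRIBE reports Type = EXTERNAL for a table created with an explicit LOCATION.
    spark.sql("DESCRIBE TABLE EXTENDED demo_tbl").show(truncate = false)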
- if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { - val newProps = new java.util.HashMap[String, String]() - newProps.putAll(properties) + if (!newProps.containsKey(TableCatalog.PROP_LOCATION)) { newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") newProps } else { - properties + newProps } } else { - properties + newProps } - val created = super.createTable(ident, schema, partitions, propsWithLocation) - val t = newTable(created.name(), schema, partitions, propsWithLocation) + super.createTable(ident, schema, partitions, propsWithLocation) + val t = newTable(ident.quoted, schema, partitions, propsWithLocation) addTable(ident, t) t } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 6cab0e0239dc4..40938eb642478 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal} import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, RangePartitioning, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} @@ -264,11 +264,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data"), attr("id")) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -377,11 +374,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data")) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -875,11 +869,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data")) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -963,11 +954,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) val writePartitioningExprs = Seq(attr("data")) - val 
writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -1154,11 +1142,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) val writePartitioningExprs = Seq(truncateExpr) - val writePartitioning = if (!coalesce) { - clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) - } else { - clusteredWritePartitioning(writePartitioningExprs, Some(1)) - } + val writePartitioning = clusteredWritePartitioning( + writePartitioningExprs, targetNumPartitions, coalesce) checkWriteRequirements( tableDistribution, @@ -1422,6 +1407,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase case p: physical.HashPartitioning => val resolvedExprs = p.expressions.map(resolveAttrs(_, plan)) p.copy(expressions = resolvedExprs) + case c: physical.CoalescedHashPartitioning => + val resolvedExprs = c.from.expressions.map(resolveAttrs(_, plan)) + c.copy(from = c.from.copy(expressions = resolvedExprs)) case _: UnknownPartitioning => // don't check partitioning if no particular one is expected actualPartitioning @@ -1480,9 +1468,16 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase private def clusteredWritePartitioning( writePartitioningExprs: Seq[catalyst.expressions.Expression], - targetNumPartitions: Option[Int]): physical.Partitioning = { - HashPartitioning(writePartitioningExprs, - targetNumPartitions.getOrElse(conf.numShufflePartitions)) + targetNumPartitions: Option[Int], + coalesce: Boolean): physical.Partitioning = { + val partitioning = HashPartitioning(writePartitioningExprs, + targetNumPartitions.getOrElse(conf.numShufflePartitions)) + if (coalesce) { + CoalescedHashPartitioning( + partitioning, Seq(CoalescedBoundary(0, partitioning.numPartitions))) + } else { + partitioning + } } private def partitionSizes(dataSkew: Boolean, coalesce: Boolean): Seq[Option[Long]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 7f938deaaa645..ac57c958828b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -646,18 +646,6 @@ class QueryCompilationErrorsSuite parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\"")) } - test("UNSUPPORTED_GENERATOR: only one generator allowed") { - val e = intercept[AnalysisException]( - sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect() - ) - - checkError( - exception = e, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map("clause" -> "SELECT", "num" -> "2", - "generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2, 3))\"")) - } - test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") { val e = intercept[AnalysisException]( sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index 7ebb677b12158..84857b972918a 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -604,6 +604,13 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL sqlState = "42K01", parameters = Map("elementType" -> ""), context = ExpectedContext(fragment = "ARRAY", start = 30, stop = 34)) + // Create column of array type without specifying element type in lowercase + checkError( + exception = parseException("CREATE TABLE tbl_120691 (col1 array)"), + errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", + sqlState = "42K01", + parameters = Map("elementType" -> ""), + context = ExpectedContext(fragment = "array", start = 30, stop = 34)) } test("INCOMPLETE_TYPE_DEFINITION: struct type definition is incomplete") { @@ -631,6 +638,12 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL errorClass = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'>'", "hint" -> "")) + // Create column of struct type without specifying field type in lowercase + checkError( + exception = parseException("CREATE TABLE tbl_120691 (col1 struct)"), + errorClass = "INCOMPLETE_TYPE_DEFINITION.STRUCT", + sqlState = "42K01", + context = ExpectedContext(fragment = "struct", start = 30, stop = 35)) } test("INCOMPLETE_TYPE_DEFINITION: map type definition is incomplete") { @@ -652,6 +665,12 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL errorClass = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'>'", "hint" -> "")) + // Create column of map type without specifying key/value types in lowercase + checkError( + exception = parseException("SELECT CAST(map('1',2) AS map)"), + errorClass = "INCOMPLETE_TYPE_DEFINITION.MAP", + sqlState = "42K01", + context = ExpectedContext(fragment = "map", start = 26, stop = 28)) } test("INVALID_ESC: Escape string must contain only one character") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala index 24a98dd83f33a..e11191da6a952 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -310,6 +310,67 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { } } + test("SPARK-46590 adaptive query execution works correctly with broadcast join and union") { + val test: SparkSession => Unit = { spark: SparkSession => + import spark.implicits._ + spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1KB") + spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "10KB") + spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR, 2.0) + val df00 = spark.range(0, 1000, 2) + .selectExpr("id as key", "id as value") + .union(Seq.fill(100000)((600, 600)).toDF("key", "value")) + val df01 = spark.range(0, 1000, 3) + .selectExpr("id as key", "id as value") + val df10 = spark.range(0, 1000, 5) + .selectExpr("id as key", "id as value") + .union(Seq.fill(500000)((600, 600)).toDF("key", "value")) + val df11 = spark.range(0, 1000, 7) + .selectExpr("id as key", "id as value") + val df20 = spark.range(0, 10).selectExpr("id as key", "id as value") + + df20.join(df00.join(df01, Array("key", "value"), "left_outer") + .union(df10.join(df11, Array("key", "value"), "left_outer"))) + 
.write + .format("noop") + .mode("overwrite") + .save() + } + withSparkSession(test, 12000, None) + } + + test("SPARK-46590 adaptive query execution works correctly with cartesian join and union") { + val test: SparkSession => Unit = { spark: SparkSession => + import spark.implicits._ + spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") + spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "100B") + spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR, 2.0) + val df00 = spark.range(0, 10, 2) + .selectExpr("id as key", "id as value") + .union(Seq.fill(1000)((600, 600)).toDF("key", "value")) + val df01 = spark.range(0, 10, 3) + .selectExpr("id as key", "id as value") + val df10 = spark.range(0, 10, 5) + .selectExpr("id as key", "id as value") + .union(Seq.fill(5000)((600, 600)).toDF("key", "value")) + val df11 = spark.range(0, 10, 7) + .selectExpr("id as key", "id as value") + val df20 = spark.range(0, 10) + .selectExpr("id as key", "id as value") + .union(Seq.fill(1000)((11, 11)).toDF("key", "value")) + val df21 = spark.range(0, 10) + .selectExpr("id as key", "id as value") + + df20.join(df21.hint("shuffle_hash"), Array("key", "value"), "left_outer") + .join(df00.join(df01.hint("shuffle_hash"), Array("key", "value"), "left_outer") + .union(df10.join(df11.hint("shuffle_hash"), Array("key", "value"), "left_outer"))) + .write + .format("noop") + .mode("overwrite") + .save() + } + withSparkSession(test, 100, None) + } + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { val test: SparkSession => Unit = { spark: SparkSession => spark.sql("SET spark.sql.exchange.reuse=true") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index f5839e9975602..ec13d48d45f84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType class ProjectedOrderingAndPartitioningSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { @@ -101,6 +104,22 @@ class ProjectedOrderingAndPartitioningSuite } } + test("SPARK-46609: Avoid exponential explosion in PartitioningPreservingUnaryExecNode") { + withSQLConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT.key -> "2") { + val output = Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)()) + val plan = ProjectExec( + Seq( + Alias(output(0), "a1")(), + Alias(output(0), "a2")(), + Alias(output(1), "b1")(), + Alias(output(1), "b2")() + ), + DummyLeafPlanExec(output) + ) + 
assert(plan.outputPartitioning.asInstanceOf[PartitioningCollection].partitionings.length == 2) + } + } + test("SPARK-42049: Improve AliasAwareOutputExpression - multi-references to complex " + "expressions") { val df2 = spark.range(2).repartition($"id" + $"id").selectExpr("id + id as a", "id + id as b") @@ -192,3 +211,10 @@ class ProjectedOrderingAndPartitioningSuite assert(outputOrdering.head.sameOrderExpressions.size == 0) } } + +private case class DummyLeafPlanExec(output: Seq[Attribute]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = null + override def outputPartitioning: Partitioning = { + PartitioningCollection(output.map(attr => HashPartitioning(Seq(attr), 4))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index e258d600a2aa8..a1147c16cc861 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1216,4 +1216,18 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } } + + test("Inline table with current time expression") { + withView("v1") { + sql("CREATE VIEW v1 (t1, t2) AS SELECT * FROM VALUES (now(), now())") + val r1 = sql("select t1, t2 from v1").collect()(0) + val ts1 = (r1.getTimestamp(0), r1.getTimestamp(1)) + assert(ts1._1 == ts1._2) + Thread.sleep(1) + val r2 = sql("select t1, t2 from v1").collect()(0) + val ts2 = (r2.getTimestamp(0), r2.getTimestamp(1)) + assert(ts2._1 == ts2._2) + assert(ts1._1.getTime < ts2._1.getTime) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala index da05373125d31..f8b7964368476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala @@ -567,14 +567,13 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { } { - // Assertion error if shuffle partition specs contain `CoalescedShuffleSpec` that has - // `end` - `start` > 1. + // If shuffle partition specs contain `CoalescedShuffleSpec` that has + // `end` - `start` > 1, return empty result. val bytesByPartitionId1 = Array[Long](10, 10, 10, 10, 10) val bytesByPartitionId2 = Array[Long](10, 10, 10, 10, 10) val specs1 = Seq(CoalescedPartitionSpec(0, 1), CoalescedPartitionSpec(1, 5)) val specs2 = specs1 - intercept[AssertionError] { - ShufflePartitionsUtil.coalescePartitions( + val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), Some(new MapOutputStatistics(1, bytesByPartitionId2))), @@ -582,17 +581,16 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { Some(specs1), Some(specs2)), targetSize, 1, 0) - } + assert(coalesced.isEmpty) } { - // Assertion error if shuffle partition specs contain `PartialMapperShuffleSpec`. + // If shuffle partition specs contain `PartialMapperShuffleSpec`, return empty result. 
val bytesByPartitionId1 = Array[Long](10, 10, 10, 10, 10) val bytesByPartitionId2 = Array[Long](10, 10, 10, 10, 10) val specs1 = Seq(CoalescedPartitionSpec(0, 1), PartialMapperPartitionSpec(1, 0, 1)) val specs2 = specs1 - intercept[AssertionError] { - ShufflePartitionsUtil.coalescePartitions( + val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), Some(new MapOutputStatistics(1, bytesByPartitionId2))), @@ -600,18 +598,17 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { Some(specs1), Some(specs2)), targetSize, 1, 0) - } + assert(coalesced.isEmpty) } { - // Assertion error if partition specs of different shuffles have different lengths. + // If partition specs of different shuffles have different lengths, return empty result. val bytesByPartitionId1 = Array[Long](10, 10, 10, 10, 10) val bytesByPartitionId2 = Array[Long](10, 10, 10, 10, 10) val specs1 = Seq.tabulate(4)(i => CoalescedPartitionSpec(i, i + 1)) ++ Seq.tabulate(2)(i => PartialReducerPartitionSpec(4, i, i + 1, 10L)) val specs2 = Seq.tabulate(5)(i => CoalescedPartitionSpec(i, i + 1)) - intercept[AssertionError] { - ShufflePartitionsUtil.coalescePartitions( + val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), Some(new MapOutputStatistics(1, bytesByPartitionId2))), @@ -619,11 +616,12 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { Some(specs1), Some(specs2)), targetSize, 1, 0) - } + assert(coalesced.isEmpty) } { - // Assertion error if start indices of partition specs are not identical among all shuffles. + // If start indices of partition specs are not identical among all shuffles, + // return empty result. 
val bytesByPartitionId1 = Array[Long](10, 10, 10, 10, 10) val bytesByPartitionId2 = Array[Long](10, 10, 10, 10, 10) val specs1 = Seq.tabulate(4)(i => CoalescedPartitionSpec(i, i + 1)) ++ @@ -631,8 +629,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val specs2 = Seq.tabulate(2)(i => CoalescedPartitionSpec(i, i + 1)) ++ Seq.tabulate(2)(i => PartialReducerPartitionSpec(2, i, i + 1, 10L)) ++ Seq.tabulate(2)(i => CoalescedPartitionSpec(i + 3, i + 4)) - intercept[AssertionError] { - ShufflePartitionsUtil.coalescePartitions( + val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), Some(new MapOutputStatistics(1, bytesByPartitionId2))), @@ -640,7 +637,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { Some(specs1), Some(specs2)), targetSize, 1, 0) - } + assert(coalesced.isEmpty) } { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index d949342106159..928d732f2a160 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -130,7 +130,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { assert(sorter.numSpills > 0) // Merging spilled files should not throw assertion error - sorter.writePartitionedMapOutput(0, 0, mapOutputWriter) + sorter.writePartitionedMapOutput(0, 0, mapOutputWriter, + taskContext.taskMetrics.shuffleWriteMetrics) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala index e253de76221ad..69145d890fc19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.SparkSubmitTestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.functions.{array, col, count, lit} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.unsafe.Platform @@ -70,39 +71,41 @@ class WholeStageCodegenSparkSubmitSuite extends SparkSubmitTestUtils object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { - var spark: SparkSession = _ - def main(args: Array[String]): Unit = { TestUtils.configTestLog4j2("INFO") - spark = SparkSession.builder().getOrCreate() + val spark = SparkSession.builder() + .config(SQLConf.SHUFFLE_PARTITIONS.key, "2") + .getOrCreate() + + try { + // Make sure the test is run where the driver and the executors uses different object layouts + val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET + val executorArrayHeaderSize = + spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect().head + assert(driverArrayHeaderSize > executorArrayHeaderSize) - // Make sure the test is run where the driver and the executors uses different object layouts - val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET - val executorArrayHeaderSize 
= - spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt - assert(driverArrayHeaderSize > executorArrayHeaderSize) + val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v") + .groupBy(array(col("v"))).agg(count(col("*"))) + val plan = df.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) - val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v") - .groupBy(array(col("v"))).agg(count(col("*"))) - val plan = df.queryExecution.executedPlan - assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) + val expectedAnswer = + Row(Array(0), 7178) :: + Row(Array(1), 7178) :: + Row(Array(2), 7178) :: + Row(Array(3), 7177) :: + Row(Array(4), 7177) :: + Row(Array(5), 7177) :: + Row(Array(6), 7177) :: + Row(Array(7), 7177) :: + Row(Array(8), 7177) :: + Row(Array(9), 7177) :: Nil - val expectedAnswer = - Row(Array(0), 7178) :: - Row(Array(1), 7178) :: - Row(Array(2), 7178) :: - Row(Array(3), 7177) :: - Row(Array(4), 7177) :: - Row(Array(5), 7177) :: - Row(Array(6), 7177) :: - Row(Array(7), 7177) :: - Row(Array(8), 7177) :: - Row(Array(9), 7177) :: Nil - val result = df.collect - QueryTest.sameRows(result.toSeq, expectedAnswer) match { - case Some(errMsg) => fail(errMsg) - case _ => + QueryTest.checkAnswer(df, expectedAnswer) + } finally { + spark.stop() } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 68bae34790a00..f6b96ee7e1ebd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -26,12 +26,13 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec @@ -60,7 +61,8 @@ class AdaptiveQueryExecSuite setupTestData() - private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { + private def runAdaptiveAndVerifyResult(query: String, + skipCheckAnswer: Boolean = false): (SparkPlan, SparkPlan) = { var finalPlanCnt = 0 var hasMetricsEvent = false val listener = new SparkListener { @@ -84,8 +86,10 @@ class AdaptiveQueryExecSuite assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false")) val result = dfAdaptive.collect() withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { - val 
df = sql(query) - checkAnswer(df, result) + if (!skipCheckAnswer) { + val df = sql(query) + checkAnswer(df, result) + } } val planAfter = dfAdaptive.queryExecution.executedPlan assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) @@ -157,6 +161,12 @@ class AdaptiveQueryExecSuite } } + private def findTopLevelUnion(plan: SparkPlan): Seq[UnionExec] = { + collect(plan) { + case l: UnionExec => l + } + } + private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { collectWithSubqueries(plan) { case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e @@ -2405,6 +2415,28 @@ class AdaptiveQueryExecSuite } } + test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " + + "resulting in potentially inaccurate data") { + withTable("t3") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> (SortShuffleManager + .MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1).toString) { + sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)") + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT id, count(*) + |FROM t3 + |GROUP BY id + |LIMIT 1 + |""".stripMargin, skipCheckAnswer = true) + // The shuffle stage produces two rows and the limit operator should not be optimized out. + assert(findTopLevelLimit(plan).size == 1) + assert(findTopLevelLimit(adaptivePlan).size == 1) + } + } + } + test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") { withTempView("v") { withSQLConf( @@ -2675,6 +2707,67 @@ class AdaptiveQueryExecSuite } } + test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + // Before SPARK-48155, since AQE calls ValidateSparkPlan, + // none of the AQE optimizer rules took effect and the original plan was returned. + // After SPARK-48155, Spark avoids invalid propagation of empty relations: + // the empty relation from the UNION's first child is propagated correctly, + // while the JOIN is not propagated, since that would generate an invalid plan.
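For readers unfamiliar with the harness used throughout this suite, a standalone, spark-shell style sketch of the AQE round trip that runAdaptiveAndVerifyResult relies on: the executed-plan string only flips to isFinalPlan=true after the query has actually run (the grouping expression and alias below are arbitrary).

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    val df = spark.range(1000).groupBy((col("id") % 10).as("k")).count()
    // Before execution the adaptive plan is not final yet.
    assert(df.queryExecution.executedPlan.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false"))
    df.collect()
    // After execution the same plan object reports the final plan.
    assert(df.queryExecution.executedPlan.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))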
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 + |INNER JOIN ( + | SELECT * FROM testData2 + | WHERE b = 0 + | UNION ALL + | SELECT * FROM testData2 + | WHErE b != 0 + |) t2 + |ON t1.b = t2.b AND t1.a = 0 + |RIGHT OUTER JOIN testData2 t3 + |ON t1.a > t3.a + |GROUP BY t3.b + """.stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + assert(findTopLevelUnion(adaptivePlan).size == 0) + } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF().createOrReplaceTempView("t1") + spark.range(100).createOrReplaceTempView("t2") + spark.range(2).createOrReplaceTempView("t3") + spark.range(2).createOrReplaceTempView("t4") + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT tt2.value + |FROM ( + | SELECT value + | FROM t1 + | WHERE NOT EXISTS ( + | SELECT 1 + | FROM ( + | SELECT t2.id + | FROM t2 + | JOIN t3 ON t2.id = t3.id + | AND t2.id > 100 + | ) tt + | WHERE t1.value = tt.id + | ) + | AND t1.value = 1 + |) tt2 + | LEFT JOIN t4 ON tt2.value = t4.id + |""".stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + } + } + } + test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") { // partitioning: HashPartitioning @@ -2758,7 +2851,7 @@ class AdaptiveQueryExecSuite case s: SortExec => s }.size == (if (firstAccess) 2 else 0)) assert(collect(initialExecutedPlan) { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i }.head.isMaterialized != firstAccess) df.collect() @@ -2770,7 +2863,7 @@ class AdaptiveQueryExecSuite case s: SortExec => s }.isEmpty) assert(collect(initialExecutedPlan) { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i }.head.isMaterialized) } @@ -2851,6 +2944,27 @@ class AdaptiveQueryExecSuite val unionDF = aggDf1.union(aggDf2) checkAnswer(unionDF.select("id").distinct, Seq(Row(null))) } + + test("SPARK-49979: AQE hang forever when collecting twice on a failed AQE plan") { + val func: Long => Boolean = (i : Long) => { + throw new Exception("SPARK-49979") + } + withUserDefinedFunction("func" -> true) { + spark.udf.register("func", func) + val df1 = spark.range(1024).select($"id".as("key1")) + val df2 = spark.range(2048).select($"id".as("key2")) + .withColumn("group_key", $"key2" % 1024) + val df = df1.filter(expr("func(key1)")).hint("MERGE").join(df2, $"key1" === $"key2") + .groupBy($"group_key").agg("key1" -> "count") + intercept[Throwable] { + df.collect() + } + // second collect should not hang forever + intercept[Throwable] { + df.collect() + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala index 55d9fb2731799..1f132dabd2878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.execution.ColumnarToRowExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark /** @@ -33,11 +34,11 @@ import 
org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark * Results will be written to "benchmarks/InMemoryColumnarBenchmark-results.txt". * }}} */ -object InMemoryColumnarBenchmark extends SqlBasedBenchmark { +object InMemoryColumnarBenchmark extends SqlBasedBenchmark with AdaptiveSparkPlanHelper { def intCache(rowsNum: Long, numIters: Int): Unit = { val data = spark.range(0, rowsNum, 1, 1).toDF("i").cache() - val inMemoryScan = data.queryExecution.executedPlan.collect { + val inMemoryScan = collect(data.queryExecution.executedPlan) { case m: InMemoryTableScanExec => m } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala index 72b3a4bc1095a..2c73622739a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryRelationSuite.scala @@ -18,20 +18,42 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.test.SharedSparkSessionBase import org.apache.spark.storage.StorageLevel -class InMemoryRelationSuite extends SparkFunSuite with SharedSparkSessionBase { - test("SPARK-43157: Clone innerChildren cached plan") { +class InMemoryRelationSuite extends SparkFunSuite + with SharedSparkSessionBase with AdaptiveSparkPlanHelper { + + test("SPARK-46779: InMemoryRelations with the same cached plan are semantically equivalent") { val d = spark.range(1) - val relation = InMemoryRelation(StorageLevel.MEMORY_ONLY, d.queryExecution, None) - val cloned = relation.clone().asInstanceOf[InMemoryRelation] + val r1 = InMemoryRelation(StorageLevel.MEMORY_ONLY, d.queryExecution, None) + val r2 = r1.withOutput(r1.output.map(_.newInstance())) + assert(r1.sameResult(r2)) + } + + test("SPARK-47177: Cached SQL plan do not display final AQE plan in explain string") { + def findIMRInnerChild(p: SparkPlan): SparkPlan = { + val tableCache = find(p) { + case _: InMemoryTableScanExec => true + case _ => false + } + assert(tableCache.isDefined) + tableCache.get.asInstanceOf[InMemoryTableScanExec].relation.innerChildren.head + } - val relationCachedPlan = relation.innerChildren.head - val clonedCachedPlan = cloned.innerChildren.head + val d1 = spark.range(1).withColumn("key", expr("id % 100")) + .groupBy("key").agg(Map("key" -> "count")) + val cached_d2 = d1.cache() + val df = cached_d2.withColumn("key2", expr("key % 10")) + .groupBy("key2").agg(Map("key2" -> "count")) - // verify the plans are not the same object but are logically equivalent - assert(!relationCachedPlan.eq(clonedCachedPlan)) - assert(relationCachedPlan === clonedCachedPlan) + assert(findIMRInnerChild(df.queryExecution.executedPlan).treeString + .contains("AdaptiveSparkPlan isFinalPlan=false")) + df.collect() + assert(findIMRInnerChild(df.queryExecution.executedPlan).treeString + .contains("AdaptiveSparkPlan isFinalPlan=true")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index a2f3d872a68e9..2979d3cdcab56 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.internal.SQLConf @@ -160,6 +160,8 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("cat") newCatalog } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index dea66bb09cfac..1124184dded7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER import org.apache.spark.sql.internal.SQLConf @@ -47,6 +48,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { try { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() + spark.sessionState.catalogManager.reset() } finally { Utils.deleteRecursively(new File(spark.sessionState.conf.warehousePath)) super.afterEach() @@ -218,7 +220,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { test("SPARK-25403 refresh the table after inserting data") { withTable("t") { val catalog = spark.sessionState.catalog - val table = QualifiedTableName(catalog.getCurrentDatabase, "t") + val table = QualifiedTableName( + CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") sql("CREATE TABLE t (a INT) USING parquet") sql("INSERT INTO TABLE t VALUES (1)") assert(catalog.getCachedTable(table) === null, "Table relation should be invalidated.") @@ -231,7 +234,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { withTable("t") { withTempDir { dir => val catalog = spark.sessionState.catalog - val table = QualifiedTableName(catalog.getCurrentDatabase, "t") + val table = QualifiedTableName( + CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") val 
p1 = s"${dir.getCanonicalPath}/p1" val p2 = s"${dir.getCanonicalPath}/p2" sql(s"CREATE TABLE t (a INT) USING parquet LOCATION '$p1'") @@ -1120,7 +1124,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") assert(getProps == Map("andrew" -> "or14", "kor" -> "bel")) // set table properties without explicitly specifying database - catalog.setCurrentDatabase("dbx") + spark.sessionState.catalogManager.setCurrentNamespace(Array("dbx")) sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") assert(getProps == Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) // table to alter does not exist @@ -1154,7 +1158,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") assert(getProps == Map("p" -> "an", "c" -> "lan", "x" -> "y")) // unset table properties without explicitly specifying database - catalog.setCurrentDatabase("dbx") + spark.sessionState.catalogManager.setCurrentNamespace(Array("dbx")) sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") assert(getProps == Map("c" -> "lan", "x" -> "y")) // table to alter does not exist diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 8eb0d5456c111..d738270699bd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCom import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.connector.FakeV2Provider -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Column, ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Column, ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, TableWritePrivilege, V1Table} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} @@ -157,6 +157,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("testcat") newCatalog } @@ -174,6 +176,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn(CatalogManager.SESSION_CATALOG_NAME) newCatalog } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala index cd0a057284705..747a378275019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, F import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.command import org.apache.spark.sql.execution.command.FakeLocalFsFileSystem import org.apache.spark.sql.internal.SQLConf @@ -146,7 +147,8 @@ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { spark.table(t) val catalog = spark.sessionState.catalog - val qualifiedTableName = QualifiedTableName("ns", "tbl") + val qualifiedTableName = + QualifiedTableName(CatalogManager.SESSION_CATALOG_NAME, "ns", "tbl") val cachedPlan = catalog.getCachedTable(qualifiedTableName) assert(cachedPlan.stats.sizeInBytes == 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala index e2f2aee56115f..a21baebe24d8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala @@ -175,4 +175,25 @@ class DescribeTableSuite extends command.DescribeTableSuiteBase Row("max_col_len", "NULL"))) } } + + test("SPARK-46535: describe extended (formatted) a column without col stats") { + withNamespaceAndTable("ns", "tbl") { tbl => + sql( + s""" + |CREATE TABLE $tbl + |(key INT COMMENT 'column_comment', col STRING) + |$defaultUsing""".stripMargin) + + val descriptionDf = sql(s"DESCRIBE TABLE EXTENDED $tbl key") + assert(descriptionDf.schema.map(field => (field.name, field.dataType)) === Seq( + ("info_name", StringType), + ("info_value", StringType))) + QueryTest.checkAnswer( + descriptionDf, + Seq( + Row("col_name", "key"), + Row("data_type", "int"), + Row("comment", "column_comment"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala index 06e570cb016b0..90b341ae1f2cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.PrivateMethodTester @@ -214,4 +216,6 @@ class MockFileSystem extends RawLocalFileSystem { override def globStatus(pathPattern: Path): Array[FileStatus] = { mockGlobResults.getOrElse(pathPattern, Array()) } + + override def getUri: URI = URI.create("mockFs://mockFs/") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCodecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCodecSuite.scala index 09a348cd29451..9f3d6ff48d477 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCodecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCodecSuite.scala @@ -59,7 +59,7 @@ class ParquetCodecSuite extends FileSourceCodecSuite { // Exclude 
"brotli" because the com.github.rdblue:brotli-codec dependency is not available // on Maven Central. override protected def availableCodecs: Seq[String] = { - Seq("none", "uncompressed", "snappy", "gzip", "zstd", "lz4", "lz4raw") + Seq("none", "uncompressed", "snappy", "gzip", "zstd", "lz4", "lz4raw", "lz4_raw") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 3bd45ca0dcdb3..3762c00ff1a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -80,6 +80,7 @@ abstract class CSVSuite private val valueMalformedFile = "test-data/value-malformed.csv" private val badAfterGoodFile = "test-data/bad_after_good.csv" private val malformedRowFile = "test-data/malformedRow.csv" + private val charFile = "test-data/char.csv" /** Verifies data and schema. */ private def verifyCars( @@ -1105,10 +1106,12 @@ abstract class CSVSuite test("SPARK-37326: Timestamp type inference for a column with TIMESTAMP_NTZ values") { withTempPath { path => - val exp = spark.sql(""" - select timestamp_ntz'2020-12-12 12:12:12' as col0 union all - select timestamp_ntz'2020-12-12 12:12:12' as col0 - """) + val exp = spark.sql( + """ + |select * + |from values (timestamp_ntz'2020-12-12 12:12:12'), (timestamp_ntz'2020-12-12 12:12:12') + |as t(col0) + |""".stripMargin) exp.write.format("csv").option("header", "true").save(path.getAbsolutePath) @@ -1126,6 +1129,15 @@ abstract class CSVSuite if (timestampType == SQLConf.TimestampTypes.TIMESTAMP_NTZ.toString) { checkAnswer(res, exp) + } else if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { + // When legacy parser is enabled, we can't parse the NTZ string to LTZ, and eventually + // infer string type. 
+ val expected = spark.read + .format("csv") + .option("inferSchema", "false") + .option("header", "true") + .load(path.getAbsolutePath) + checkAnswer(res, expected) } else { checkAnswer( res, @@ -2068,6 +2080,7 @@ abstract class CSVSuite .option("header", true) .option("enforceSchema", false) .option("multiLine", multiLine) + .option("columnPruning", true) .load(dir) .select("columnA"), Row("a")) @@ -2078,6 +2091,7 @@ abstract class CSVSuite .option("header", true) .option("enforceSchema", false) .option("multiLine", multiLine) + .option("columnPruning", true) .load(dir) .count() === 1L) } @@ -2862,13 +2876,12 @@ abstract class CSVSuite test("SPARK-40474: Infer schema for columns with a mix of dates and timestamp") { withTempPath { path => - Seq( - "1765-03-28", + val input = Seq( "1423-11-12T23:41:00", + "1765-03-28", "2016-01-28T20:00:00" - ).toDF() - .repartition(1) - .write.text(path.getAbsolutePath) + ).toDF().repartition(1) + input.write.text(path.getAbsolutePath) if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { val options = Map( @@ -2879,12 +2892,7 @@ abstract class CSVSuite .format("csv") .options(options) .load(path.getAbsolutePath) - val expected = Seq( - Row(Timestamp.valueOf("1765-03-28 00:00:00.0")), - Row(Timestamp.valueOf("1423-11-12 23:41:00.0")), - Row(Timestamp.valueOf("2016-01-28 20:00:00.0")) - ) - checkAnswer(df, expected) + checkAnswer(df, input) } else { // When timestampFormat is specified, infer and parse the column as strings val options1 = Map( @@ -2895,12 +2903,7 @@ abstract class CSVSuite .format("csv") .options(options1) .load(path.getAbsolutePath) - val expected1 = Seq( - Row("1765-03-28"), - Row("1423-11-12T23:41:00"), - Row("2016-01-28T20:00:00") - ) - checkAnswer(df1, expected1) + checkAnswer(df1, input) // When timestampFormat is not specified, infer and parse the column as // timestamp type if possible @@ -3151,7 +3154,7 @@ abstract class CSVSuite } test("SPARK-40667: validate CSV Options") { - assert(CSVOptions.getAllOptions.size == 38) + assert(CSVOptions.getAllOptions.size == 39) // Please add validation on any new CSV options here assert(CSVOptions.isValidOption("header")) assert(CSVOptions.isValidOption("inferSchema")) @@ -3191,6 +3194,7 @@ abstract class CSVSuite assert(CSVOptions.isValidOption("codec")) assert(CSVOptions.isValidOption("sep")) assert(CSVOptions.isValidOption("delimiter")) + assert(CSVOptions.isValidOption("columnPruning")) // Please add validation on any new parquet options with alternative here assert(CSVOptions.getAlternativeOption("sep").contains("delimiter")) assert(CSVOptions.getAlternativeOption("delimiter").contains("sep")) @@ -3200,6 +3204,52 @@ abstract class CSVSuite assert(CSVOptions.getAlternativeOption("codec").contains("compression")) assert(CSVOptions.getAlternativeOption("preferDate").isEmpty) } + + test("SPARK-46862: column pruning in the multi-line mode") { + val data = + """"jobID","Name","City","Active" + |"1","DE","","Yes" + |"5",",","","," + |"3","SA","","No" + |"10","abcd""efgh"" \ndef","","" + |"8","SE","","No"""".stripMargin + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + Seq(true, false).foreach { enforceSchema => + val df = spark.read + .option("multiLine", true) + .option("header", true) + .option("escape", "\"") + .option("enforceSchema", enforceSchema) + .csv(path.getCanonicalPath) + assert(df.count() === 5) + } + } + } + + test("SPARK-48241: CSV parsing failure with char/varchar type columns") { + withTable("charVarcharTable") { 
+      spark.sql(
+        s"""
+           |CREATE TABLE charVarcharTable(
+           |  color char(4),
+           |  name varchar(10))
+           |USING csv
+           |OPTIONS (
+           |  header "true",
+           |  path "${testFile(charFile)}"
+           |)
+        """.stripMargin)
+      val expected = Seq(
+        Row("pink", "Bob"),
+        Row("blue", "Mike"),
+        Row("grey", "Tom"))
+      checkAnswer(
+        sql("SELECT * FROM charVarcharTable"),
+        expected)
+    }
+  }
 }
 class CSVv1Suite extends CSVSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 11779286ec25f..3bb193cb8f10b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, Da
 import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.StructType.fromDDL
 import org.apache.spark.sql.types.TestUDT.{MyDenseVector, MyDenseVectorUDT}
@@ -3654,6 +3655,78 @@ abstract class JsonSuite
     assert(JSONOptions.getAlternativeOption("charset").contains("encoding"))
     assert(JSONOptions.getAlternativeOption("dateFormat").isEmpty)
   }
+
+  test("SPARK-47704: Handle partial parsing of array<map>") {
+    withTempPath { path =>
+      Seq("""{"a":[{"key":{"b":0}}]}""").toDF()
+        .repartition(1)
+        .write.text(path.getAbsolutePath)
+
+      for (enablePartialResults <- Seq(true, false)) {
+        withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> s"$enablePartialResults") {
+          val df = spark.read
+            .schema("a array<map<string, struct<b boolean>>>")
+            .json(path.getAbsolutePath)
+
+          if (enablePartialResults) {
+            checkAnswer(df, Seq(Row(Array(Map("key" -> Row(null))))))
+          } else {
+            checkAnswer(df, Seq(Row(null)))
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-47704: Handle partial parsing of map<array>") {
+    withTempPath { path =>
+      Seq("""{"a":{"key":[{"b":0}]}}""").toDF()
+        .repartition(1)
+        .write.text(path.getAbsolutePath)
+
+      for (enablePartialResults <- Seq(true, false)) {
+        withSQLConf(SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> s"$enablePartialResults") {
+          val df = spark.read
+            .schema("a map<string, array<struct<b boolean>>>")
+            .json(path.getAbsolutePath)
+
+          if (enablePartialResults) {
+            checkAnswer(df, Seq(Row(Map("key" -> Seq(Row(null))))))
+          } else {
+            checkAnswer(df, Seq(Row(null)))
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-48965: Dataset#toJSON should use correct schema #1: decimals") {
+    val numString = "123.456"
+    val bd = BigDecimal(numString)
+    val ds1 = sql(s"select ${numString}bd as a, ${numString}bd as b").as[DecimalData]
+    checkDataset(
+      ds1,
+      DecimalData(bd, bd)
+    )
+    val ds2 = ds1.toJSON
+    checkDataset(
+      ds2,
+      "{\"a\":123.456000000000000000,\"b\":123.456000000000000000}"
+    )
+  }
+
+  test("SPARK-48965: Dataset#toJSON should use correct schema #2: misaligned columns") {
+    val ds1 = sql("select 'Hey there' as value, 90000001 as key").as[TestData]
+    checkDataset(
+      ds1,
+      TestData(90000001, "Hey there")
+    )
+    val ds2 = ds1.toJSON
+    checkDataset(
+      ds2,
+      "{\"key\":90000001,\"value\":\"Hey there\"}"
+    )
+  }
 }
 class JsonV1Suite extends JsonSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
index a9389c1c21b40..06ea12f83ce75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
@@ -26,11 +26,12 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.TypeDescription
 import org.apache.spark.TestUtils
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OffHeapColumnVector}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -53,7 +54,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       requestedDataColIds: Array[Int],
       requestedPartitionColIds: Array[Int],
       resultFields: Array[StructField]): OrcColumnarBatchReader = {
-    val reader = new OrcColumnarBatchReader(4096)
+    val reader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
     reader.initBatch(
       orcFileSchema,
       resultFields,
@@ -117,7 +118,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
     val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty)
     val taskConf = sqlContext.sessionState.newHadoopConf()
     val orcFileSchema = TypeDescription.fromString(schema.simpleString)
-    val vectorizedReader = new OrcColumnarBatchReader(4096)
+    val vectorizedReader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
     val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
     val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
@@ -148,4 +149,15 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46598: off-heap mode") {
+    val reader = new OrcColumnarBatchReader(4096, MemoryMode.OFF_HEAP)
+    reader.initBatch(
+      TypeDescription.fromString("struct<col1:int,col2:int>"),
+      StructType.fromDDL("col1 int, col2 int, col3 int").fields,
+      Array(0, 1, -1),
+      Array(-1, -1, -1),
+      InternalRow.empty)
+    assert(reader.columnarBatch.column(2).isInstanceOf[OffHeapColumnVector])
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcEncryptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcEncryptionSuite.scala
index b7d29588f6bf4..575f230729ebd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcEncryptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcEncryptionSuite.scala
@@ -17,20 +17,52 @@
 package org.apache.spark.sql.execution.datasources.orc
+import java.util.{Map => JMap}
 import java.util.Random
-import org.apache.orc.impl.HadoopShimsFactory
+import scala.collection.mutable
+
+import org.apache.orc.impl.{CryptoUtils, HadoopShimsFactory, KeyProvider}
+
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.test.SharedSparkSession
 class OrcEncryptionSuite extends OrcTest with SharedSparkSession {
   import testImplicits._
+  override def sparkConf: SparkConf = {
+
super.sparkConf.set("spark.hadoop.hadoop.security.key.provider.path", "test:///") + } + + override def beforeAll(): Unit = { + // Backup `CryptoUtils#keyProviderCache` and clear it. + keyProviderCacheRef.entrySet() + .forEach(e => keyProviderCacheBackup.put(e.getKey, e.getValue)) + keyProviderCacheRef.clear() + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + // Restore `CryptoUtils#keyProviderCache`. + keyProviderCacheRef.clear() + keyProviderCacheBackup.foreach { case (k, v) => keyProviderCacheRef.put(k, v) } + } + val originalData = Seq(("123456789", "dongjoon@apache.org", "Dongjoon Hyun")) val rowDataWithoutKey = Row(null, "841626795E7D351555B835A002E3BF10669DE9B81C95A3D59E10865AC37EA7C3", "Dongjoon Hyun") + private val keyProviderCacheBackup: mutable.Map[String, KeyProvider] = mutable.Map.empty + + private val keyProviderCacheRef: JMap[String, KeyProvider] = { + val clazz = classOf[CryptoUtils] + val field = clazz.getDeclaredField("keyProviderCache") + field.setAccessible(true) + field.get(null).asInstanceOf[JMap[String, KeyProvider]] + } + test("Write and read an encrypted file") { val conf = spark.sessionState.newHadoopConf() val provider = HadoopShimsFactory.get.getHadoopKeyProvider(conf, new Random) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala index ac0aad16f1eba..27e2816ce9d94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -29,9 +29,23 @@ import org.apache.spark.sql.test.SharedSparkSession class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSparkSession { test("Test `spark.sql.parquet.compression.codec` config") { - Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO", "LZ4", "BROTLI", "ZSTD").foreach { c => + Seq( + "NONE", + "UNCOMPRESSED", + "SNAPPY", + "GZIP", + "LZO", + "LZ4", + "BROTLI", + "ZSTD", + "LZ4RAW", + "LZ4_RAW").foreach { c => withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { - val expected = if (c == "NONE") "UNCOMPRESSED" else c + val expected = c match { + case "NONE" => "UNCOMPRESSED" + case "LZ4RAW" => "LZ4_RAW" + case other => other + } val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) assert(option.compressionCodecClassName == expected) } @@ -97,7 +111,10 @@ class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSpar createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) val partitionPath = if (isPartitioned) "p=2" else "" val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" - val realCompressionCodecs = getTableCompressionCodec(path) + val realCompressionCodecs = getTableCompressionCodec(path).map { + case "LZ4_RAW" if compressionCodec == "LZ4RAW" => "LZ4RAW" + case other => other + } assert(realCompressionCodecs.forall(_ == compressionCodec)) } } @@ -105,7 +122,7 @@ class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSpar test("Create parquet table with compression") { Seq(true, false).foreach { isPartitioned => - val codecs = Seq("UNCOMPRESSED", "SNAPPY", "GZIP", "ZSTD", "LZ4") + val codecs = Seq("UNCOMPRESSED", "SNAPPY", "GZIP", "ZSTD", "LZ4", 
"LZ4RAW", "LZ4_RAW") codecs.foreach { compressionCodec => checkCompressionCodec(compressionCodec, isPartitioned) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 269a3efb7360c..8e88049f51e10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -901,6 +902,76 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } + test("don't push down filters that would result in overflows") { + val schema = StructType(Seq( + StructField("cbyte", ByteType), + StructField("cshort", ShortType), + StructField("cint", IntegerType) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + val parquetFilters = createParquetFilters(parquetSchema) + + for { + column <- Seq("cbyte", "cshort", "cint") + value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE).map(JLong.valueOf) + } { + val filters = Seq( + sources.LessThan(column, value), + sources.LessThanOrEqual(column, value), + sources.GreaterThan(column, value), + sources.GreaterThanOrEqual(column, value), + sources.EqualTo(column, value), + sources.EqualNullSafe(column, value), + sources.Not(sources.EqualTo(column, value)), + sources.In(column, Array(value)) + ) + for (filter <- filters) { + assert(parquetFilters.createFilter(filter).isEmpty, + s"Row group filter $filter shouldn't be pushed down.") + } + } + } + + test("don't push down filters when value type doesn't match column type") { + val schema = StructType(Seq( + StructField("cbyte", ByteType), + StructField("cshort", ShortType), + StructField("cint", IntegerType), + StructField("clong", LongType), + StructField("cfloat", FloatType), + StructField("cdouble", DoubleType), + StructField("cboolean", BooleanType), + StructField("cstring", StringType), + StructField("cdate", DateType), + StructField("ctimestamp", TimestampType), + StructField("cbinary", BinaryType), + StructField("cdecimal", DecimalType(10, 0)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + val parquetFilters = createParquetFilters(parquetSchema) + + val filters = Seq( + sources.LessThan("cbyte", String.valueOf("1")), + sources.LessThan("cshort", JBigDecimal.valueOf(1)), + sources.LessThan("cint", JFloat.valueOf(JFloat.NaN)), + sources.LessThan("clong", String.valueOf("1")), + sources.LessThan("cfloat", JDouble.valueOf(1.0D)), + sources.LessThan("cdouble", JFloat.valueOf(1.0F)), + sources.LessThan("cboolean", String.valueOf("true")), + sources.LessThan("cstring", Integer.valueOf(1)), + sources.LessThan("cdate", Timestamp.valueOf("2018-01-01 00:00:00")), + sources.LessThan("ctimestamp", Date.valueOf("2018-01-01")), + sources.LessThan("cbinary", Integer.valueOf(1)), + sources.LessThan("cdecimal", Integer.valueOf(1234)) + ) + for (filter <- filters) { + assert(parquetFilters.createFilter(filter).isEmpty, + s"Row group filter $filter shouldn't be pushed down.") + } + } + test("SPARK-6554: don't push down predicates which reference 
partition columns") { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2e7b26126d24f..29cb224c8787c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -160,21 +160,27 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } - test("SPARK-36182: writing and reading TimestampNTZType column") { - withTable("ts") { - sql("create table ts (c1 timestamp_ntz) using parquet") - sql("insert into ts values (timestamp_ntz'2016-01-01 10:11:12.123456')") - sql("insert into ts values (null)") - sql("insert into ts values (timestamp_ntz'1965-01-01 10:11:12.123456')") - val expectedSchema = new StructType().add(StructField("c1", TimestampNTZType)) - assert(spark.table("ts").schema == expectedSchema) - val expected = Seq( - ("2016-01-01 10:11:12.123456"), - (null), - ("1965-01-01 10:11:12.123456")) - .toDS().select($"value".cast("timestamp_ntz")) - withAllParquetReaders { - checkAnswer(sql("select * from ts"), expected) + test("SPARK-36182, SPARK-47368: writing and reading TimestampNTZType column") { + Seq("true", "false").foreach { inferNTZ => + // The SQL Conf PARQUET_INFER_TIMESTAMP_NTZ_ENABLED should not affect the file written + // by Spark. + withSQLConf(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key -> inferNTZ) { + withTable("ts") { + sql("create table ts (c1 timestamp_ntz) using parquet") + sql("insert into ts values (timestamp_ntz'2016-01-01 10:11:12.123456')") + sql("insert into ts values (null)") + sql("insert into ts values (timestamp_ntz'1965-01-01 10:11:12.123456')") + val expectedSchema = new StructType().add(StructField("c1", TimestampNTZType)) + assert(spark.table("ts").schema == expectedSchema) + val expected = Seq( + ("2016-01-01 10:11:12.123456"), + (null), + ("1965-01-01 10:11:12.123456")) + .toDS().select($"value".cast("timestamp_ntz")) + withAllParquetReaders { + checkAnswer(sql("select * from ts"), expected) + } + } } } } @@ -255,6 +261,18 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } + test("SPARK-46466: write and read TimestampNTZ with legacy rebase mode") { + withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> "LEGACY") { + withTable("ts") { + sql("create table ts (c1 timestamp_ntz) using parquet") + sql("insert into ts values (timestamp_ntz'0900-01-01 01:10:10')") + withAllParquetReaders { + checkAnswer(spark.table("ts"), sql("select timestamp_ntz'0900-01-01 01:10:10'")) + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => @@ -1095,6 +1113,26 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } + test("row group skipping doesn't overflow when reading into larger type") { + withTempPath { path => + Seq(0).toDF("a").write.parquet(path.toString) + // The vectorized and non-vectorized readers will produce different exceptions, we don't need + // to test both as this covers row group skipping. + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + // Reading integer 'a' as a long isn't supported. 
Check that an exception is raised instead + // of incorrectly skipping the single row group and producing incorrect results. + val exception = intercept[SparkException] { + spark.read + .schema("a LONG") + .parquet(path.toString) + .where(s"a < ${Long.MaxValue}") + .collect() + } + assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + } + } + } + test("SPARK-36825, SPARK-36852: create table with ANSI intervals") { withTable("tbl") { sql("create table tbl (c1 interval day, c2 interval year to month) using parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 30f46a3cac2d3..3f47c5e506ffd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -996,6 +996,27 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + test("SPARK-45346: merge schema should respect case sensitivity") { + import testImplicits._ + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withTempPath { path => + Seq(1).toDF("col").write.mode("append").parquet(path.getCanonicalPath) + Seq(2).toDF("COL").write.mode("append").parquet(path.getCanonicalPath) + val df = spark.read.option("mergeSchema", "true").parquet(path.getCanonicalPath) + if (caseSensitive) { + assert(df.columns.toSeq.sorted == Seq("COL", "col")) + assert(df.collect().length == 2) + } else { + // The final column name depends on which file is listed first, and is a bit random. 
+ assert(df.columns.toSeq.map(_.toLowerCase(java.util.Locale.ROOT)) == Seq("col")) + assert(df.collect().length == 2) + } + } + } + } + } + // ======================================= // Tests for parquet schema mismatch error // ======================================= @@ -1066,6 +1087,27 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + test("SPARK-45604: schema mismatch failure error on timestamp_ntz to array") { + import testImplicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + val timestamp = java.time.LocalDateTime.of(1, 2, 3, 4, 5) + val df1 = Seq((1, timestamp)).toDF() + val df2 = Seq((2, Array(timestamp))).toDF() + df1.write.mode("overwrite").parquet(s"$path/parquet") + df2.write.mode("append").parquet(s"$path/parquet") + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + val e = intercept[SparkException] { + spark.read.schema(df2.schema).parquet(s"$path/parquet").collect() + } + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + } + } + } + test("SPARK-40819: parquet file with TIMESTAMP(NANOS, true) (with nanosAsLong=true)") { val tsAttribute = "birthday" withSQLConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactorySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactorySuite.scala new file mode 100644 index 0000000000000..bd20307974416 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactorySuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.mockito.Mockito._ +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.execution.datasources.WriteJobDescription +import org.apache.spark.util.SerializableConfiguration + +class FileWriterFactorySuite extends SparkFunSuite with PrivateMethodTester { + + test("SPARK-48484: V2Write uses different TaskAttemptIds for different task attempts") { + val jobDescription = mock(classOf[WriteJobDescription]) + when(jobDescription.serializableHadoopConf).thenReturn( + new SerializableConfiguration(new Configuration(false))) + val committer = mock(classOf[FileCommitProtocol]) + + val writerFactory = FileWriterFactory(jobDescription, committer) + val createTaskAttemptContext = + PrivateMethod[TaskAttemptContextImpl](Symbol("createTaskAttemptContext")) + + val attemptContext = + writerFactory.invokePrivate(createTaskAttemptContext(0, 1)) + val attemptContext1 = + writerFactory.invokePrivate(createTaskAttemptContext(0, 2)) + assert(attemptContext.getTaskAttemptID.getTaskID == attemptContext1.getTaskAttemptID.getTaskID) + assert(attemptContext.getTaskAttemptID.getId != attemptContext1.getTaskAttemptID.getId) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index a5fee51dc916f..4a8a231cc54ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -315,7 +315,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE 'str%'")) + assert(predicate1.describe.equals(raw"a LIKE 'str%' ESCAPE '\'")) val v1Filter = StringStartsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) @@ -332,7 +332,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE '%str'")) + assert(predicate1.describe.equals(raw"a LIKE '%str' ESCAPE '\'")) val v1Filter = StringEndsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) @@ -349,7 +349,7 @@ class V2PredicateSuite extends SparkFunSuite { Array[Expression](ref("a"), literal)) assert(predicate1.equals(predicate2)) assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) - assert(predicate1.describe.equals("a LIKE '%str%'")) + assert(predicate1.describe.equals(raw"a LIKE '%str%' ESCAPE '\'")) val v1Filter = StringContains("a", "str") assert(v1Filter.toV2.equals(predicate1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 8f5996438e202..01033cd681b73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -125,7 +125,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(!catalog.tableExists(testIdent)) - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("db", "test_table")) @@ -143,7 +144,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(!catalog.tableExists(testIdent)) - val table = catalog.createTable(testIdent, schema, emptyTrans, properties) + catalog.createTable(testIdent, schema, emptyTrans, properties) + val table = catalog.loadTable(testIdent) val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("db", "test_table")) @@ -158,7 +160,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(!catalog.tableExists(testIdent)) - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) .map(part => quoteIdentifier(part)).mkString(".") @@ -185,26 +188,30 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(!catalog.tableExists(testIdent)) // default location - val t1 = catalog.createTable(testIdent, schema, emptyTrans, properties).asInstanceOf[V1Table] + catalog.createTable(testIdent, schema, emptyTrans, properties) + val t1 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t1.catalogTable.location === spark.sessionState.catalog.defaultTablePath(testIdent.asTableIdentifier)) catalog.dropTable(testIdent) // relative path properties.put(TableCatalog.PROP_LOCATION, "relative/path") - val t2 = catalog.createTable(testIdent, schema, emptyTrans, properties).asInstanceOf[V1Table] + catalog.createTable(testIdent, schema, emptyTrans, properties) + val t2 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t2.catalogTable.location === makeQualifiedPathWithWarehouse("db.db/relative/path")) catalog.dropTable(testIdent) // absolute path without scheme properties.put(TableCatalog.PROP_LOCATION, "/absolute/path") - val t3 = catalog.createTable(testIdent, schema, emptyTrans, properties).asInstanceOf[V1Table] + catalog.createTable(testIdent, schema, emptyTrans, properties) + val t3 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t3.catalogTable.location.toString === "file:///absolute/path") catalog.dropTable(testIdent) // absolute path with scheme properties.put(TableCatalog.PROP_LOCATION, "file:/absolute/path") - val t4 = catalog.createTable(testIdent, schema, emptyTrans, properties).asInstanceOf[V1Table] + catalog.createTable(testIdent, schema, emptyTrans, properties) + val t4 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t4.catalogTable.location.toString === "file:/absolute/path") catalog.dropTable(testIdent) } @@ -226,12 +233,11 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("loadTable") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) val loaded = catalog.loadTable(testIdent) - assert(table.name == loaded.name) - assert(table.schema == loaded.schema) - 
assert(table.properties == loaded.properties) + assert(loaded.name == testIdent.toString) + assert(loaded.schema == schema) } test("loadTable: table does not exist") { @@ -247,7 +253,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("invalidateTable") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) catalog.invalidateTable(testIdent) val loaded = catalog.loadTable(testIdent) @@ -268,11 +275,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add property") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(filterV2TableProperties(table.properties) == Map()) - val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1")) + catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1")) + val updated = catalog.loadTable(testIdent) assert(filterV2TableProperties(updated.properties) == Map("prop-1" -> "1")) val loaded = catalog.loadTable(testIdent) @@ -287,11 +296,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val properties = new util.HashMap[String, String]() properties.put("prop-1", "1") - val table = catalog.createTable(testIdent, schema, emptyTrans, properties) + catalog.createTable(testIdent, schema, emptyTrans, properties) + val table = catalog.loadTable(testIdent) assert(filterV2TableProperties(table.properties) == Map("prop-1" -> "1")) - val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2")) + catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2")) + val updated = catalog.loadTable(testIdent) assert(filterV2TableProperties(updated.properties) == Map("prop-1" -> "1", "prop-2" -> "2")) val loaded = catalog.loadTable(testIdent) @@ -306,11 +317,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val properties = new util.HashMap[String, String]() properties.put("prop-1", "1") - val table = catalog.createTable(testIdent, schema, emptyTrans, properties) + catalog.createTable(testIdent, schema, emptyTrans, properties) + val table = catalog.loadTable(testIdent) assert(filterV2TableProperties(table.properties) == Map("prop-1" -> "1")) - val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + val updated = catalog.loadTable(testIdent) assert(filterV2TableProperties(updated.properties) == Map()) val loaded = catalog.loadTable(testIdent) @@ -322,11 +335,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: remove missing property") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(filterV2TableProperties(table.properties) == Map()) - val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + val updated = catalog.loadTable(testIdent) assert(filterV2TableProperties(updated.properties) == Map()) val loaded = catalog.loadTable(testIdent) @@ -338,11 +353,13 @@ class 
V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add top-level column") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) + catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) + val updated = catalog.loadTable(testIdent) assert(updated.schema == schema.add("ts", TimestampType)) } @@ -350,12 +367,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add required column") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType, false)) + val updated = catalog.loadTable(testIdent) assert(updated.schema == schema.add("ts", TimestampType, nullable = false)) } @@ -363,12 +382,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add column with comment") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType, false, "comment text")) + val updated = catalog.loadTable(testIdent) val field = StructField("ts", TimestampType, nullable = false).withComment("comment text") assert(updated.schema == schema.add(field)) @@ -380,12 +401,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.addColumn(Array("point", "z"), DoubleType)) + val updated = catalog.loadTable(testIdent) val expectedSchema = schema.add("point", pointStruct.add("z", DoubleType)) @@ -395,7 +418,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add column to primitive field fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -413,7 +437,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add field to missing column fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -429,11 +454,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { 
test("alterTable: update column data type") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) + catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) + val updated = catalog.loadTable(testIdent) val expectedSchema = new StructType().add("id", LongType).add("data", StringType) assert(updated.schema == expectedSchema) @@ -445,12 +472,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val originalSchema = new StructType() .add("id", IntegerType, nullable = false) .add("data", StringType) - val table = catalog.createTable(testIdent, originalSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, originalSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == originalSchema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.updateColumnNullability(Array("id"), true)) + val updated = catalog.loadTable(testIdent) val expectedSchema = new StructType().add("id", IntegerType).add("data", StringType) assert(updated.schema == expectedSchema) @@ -459,7 +488,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: update missing column fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -475,12 +505,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add comment") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) + val updated = catalog.loadTable(testIdent) val expectedSchema = new StructType() .add("id", IntegerType, nullable = true, "comment text") @@ -491,7 +523,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: replace comment") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -501,8 +534,9 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { .add("id", IntegerType, nullable = true, "replacement comment") .add("data", StringType) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "replacement comment")) + val updated = catalog.loadTable(testIdent) assert(updated.schema == expectedSchema) } @@ -510,7 +544,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: add comment to missing column fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) 
assert(table.schema == schema) @@ -526,11 +561,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: rename top-level column") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) + catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) + val updated = catalog.loadTable(testIdent) val expectedSchema = new StructType().add("some_id", IntegerType).add("data", StringType) @@ -543,12 +580,14 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.renameColumn(Array("point", "x"), "first")) + val updated = catalog.loadTable(testIdent) val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType) val expectedSchema = schema.add("point", newPointStruct) @@ -562,12 +601,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) - val updated = catalog.alterTable(testIdent, - TableChange.renameColumn(Array("point"), "p")) + catalog.alterTable(testIdent, TableChange.renameColumn(Array("point"), "p")) + val updated = catalog.loadTable(testIdent) val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val expectedSchema = schema.add("p", newPointStruct) @@ -578,7 +618,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: rename missing column fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -597,13 +638,15 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) - val updated = catalog.alterTable(testIdent, + catalog.alterTable(testIdent, TableChange.renameColumn(Array("point", "x"), "first"), TableChange.renameColumn(Array("point", "y"), "second")) + val updated = catalog.loadTable(testIdent) val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType) val expectedSchema = schema.add("point", newPointStruct) @@ -614,12 +657,13 @@ class 
V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: delete top-level column") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) - val updated = catalog.alterTable(testIdent, - TableChange.deleteColumn(Array("id"), false)) + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("id"), false)) + val updated = catalog.loadTable(testIdent) val expectedSchema = new StructType().add("data", StringType) assert(updated.schema == expectedSchema) @@ -631,12 +675,13 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) - val updated = catalog.alterTable(testIdent, - TableChange.deleteColumn(Array("point", "y"), false)) + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "y"), false)) + val updated = catalog.loadTable(testIdent) val newPointStruct = new StructType().add("x", DoubleType) val expectedSchema = schema.add("point", newPointStruct) @@ -647,7 +692,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { test("alterTable: delete missing column fails") { val catalog = newCatalog() - val table = catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == schema) @@ -669,7 +715,8 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val tableSchema = schema.add("point", pointStruct) - val table = catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + catalog.createTable(testIdent, tableSchema, emptyTrans, emptyProps) + val table = catalog.loadTable(testIdent) assert(table.schema == tableSchema) @@ -700,23 +747,27 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { assert(!catalog.tableExists(testIdent)) // default location - val t1 = catalog.createTable(testIdent, schema, emptyTrans, emptyProps).asInstanceOf[V1Table] + catalog.createTable(testIdent, schema, emptyTrans, emptyProps) + val t1 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t1.catalogTable.location === spark.sessionState.catalog.defaultTablePath(testIdent.asTableIdentifier)) // relative path - val t2 = catalog.alterTable(testIdent, - TableChange.setProperty(TableCatalog.PROP_LOCATION, "relative/path")).asInstanceOf[V1Table] + catalog.alterTable(testIdent, + TableChange.setProperty(TableCatalog.PROP_LOCATION, "relative/path")) + val t2 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t2.catalogTable.location === makeQualifiedPathWithWarehouse("db.db/relative/path")) // absolute path without scheme - val t3 = catalog.alterTable(testIdent, - TableChange.setProperty(TableCatalog.PROP_LOCATION, "/absolute/path")).asInstanceOf[V1Table] + catalog.alterTable(testIdent, + TableChange.setProperty(TableCatalog.PROP_LOCATION, "/absolute/path")) + val t3 = catalog.loadTable(testIdent).asInstanceOf[V1Table] 
assert(t3.catalogTable.location.toString === "file:///absolute/path") // absolute path with scheme - val t4 = catalog.alterTable(testIdent, TableChange.setProperty( - TableCatalog.PROP_LOCATION, "file:/absolute/path")).asInstanceOf[V1Table] + catalog.alterTable(testIdent, TableChange.setProperty( + TableCatalog.PROP_LOCATION, "file:/absolute/path")) + val t4 = catalog.loadTable(testIdent).asInstanceOf[V1Table] assert(t4.catalogTable.location.toString === "file:/absolute/path") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala index 6b85911dca773..078c708cc3fdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -35,7 +36,11 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { val tempDir = Utils.createTempDir() val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" - val defaultMetadata = new MetadataBuilder().putLong("scale", 0).build() + + def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType]) + .build() override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) @@ -137,8 +142,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { test("load a table") { val t = spark.table("h2.test.people") val expectedSchema = new StructType() - .add("NAME", VarcharType(32), true, defaultMetadata) - .add("ID", IntegerType, true, defaultMetadata) + .add("NAME", VarcharType(32), true, defaultMetadata(VarcharType(32))) + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema)) Seq( "h2.test.not_existing_table" -> "`h2`.`test`.`not_existing_table`", @@ -180,13 +185,13 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName ADD COLUMNS (C1 INTEGER, C2 STRING)") var t = spark.table(tableName) var expectedSchema = new StructType() - .add("ID", IntegerType, true, defaultMetadata) - .add("C1", IntegerType, true, defaultMetadata) - .add("C2", StringType, true, defaultMetadata) + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) + .add("C1", IntegerType, true, defaultMetadata(IntegerType)) + .add("C2", StringType, true, defaultMetadata(StringType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tableName ADD COLUMNS (c3 DOUBLE)") t = spark.table(tableName) - expectedSchema = expectedSchema.add("c3", DoubleType, true, defaultMetadata) + expectedSchema = expectedSchema.add("c3", DoubleType, true, defaultMetadata(DoubleType)) assert(t.schema === expectedSchema) // Add already existing column 
checkError( @@ -224,8 +229,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName RENAME COLUMN id TO C") val t = spark.table(tableName) val expectedSchema = new StructType() - .add("C", IntegerType, true, defaultMetadata) - .add("C0", IntegerType, true, defaultMetadata) + .add("C", IntegerType, true, defaultMetadata(IntegerType)) + .add("C0", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) // Rename to already existing column checkError( @@ -263,7 +268,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName DROP COLUMN C1") sql(s"ALTER TABLE $tableName DROP COLUMN c3") val t = spark.table(tableName) - val expectedSchema = new StructType().add("C2", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("C2", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) // Drop not existing column val msg = intercept[AnalysisException] { @@ -292,8 +298,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName ALTER COLUMN deptno TYPE DOUBLE") val t = spark.table(tableName) val expectedSchema = new StructType() - .add("ID", DoubleType, true, defaultMetadata) - .add("deptno", DoubleType, true, defaultMetadata) + .add("ID", DoubleType, true, defaultMetadata(DoubleType)) + .add("deptno", DoubleType, true, defaultMetadata(DoubleType)) assert(t.schema === expectedSchema) // Update not existing column val msg1 = intercept[AnalysisException] { @@ -330,8 +336,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName ALTER COLUMN deptno DROP NOT NULL") val t = spark.table(tableName) val expectedSchema = new StructType() - .add("ID", IntegerType, true, defaultMetadata) - .add("deptno", IntegerType, true, defaultMetadata) + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) + .add("deptno", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) // Update nullability of not existing column val msg = intercept[AnalysisException] { @@ -387,8 +393,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"CREATE TABLE $tableName (c1 INTEGER NOT NULL, c2 INTEGER)") var t = spark.table(tableName) var expectedSchema = new StructType() - .add("c1", IntegerType, true, defaultMetadata) - .add("c2", IntegerType, true, defaultMetadata) + .add("c1", IntegerType, true, defaultMetadata(IntegerType)) + .add("c2", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { @@ -401,8 +407,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql(s"ALTER TABLE $tableName RENAME COLUMN C2 TO c3") expectedSchema = new StructType() - .add("c1", IntegerType, true, defaultMetadata) - .add("c3", IntegerType, true, defaultMetadata) + .add("c1", IntegerType, true, defaultMetadata(IntegerType)) + .add("c3", IntegerType, true, defaultMetadata(IntegerType)) t = spark.table(tableName) assert(t.schema === expectedSchema) } @@ -416,7 +422,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql(s"ALTER TABLE $tableName DROP COLUMN C3") - expectedSchema = new StructType().add("c1", IntegerType, true, defaultMetadata) + expectedSchema = new StructType() + .add("c1", 
IntegerType, true, defaultMetadata(IntegerType)) t = spark.table(tableName) assert(t.schema === expectedSchema) } @@ -430,7 +437,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql(s"ALTER TABLE $tableName ALTER COLUMN C1 TYPE DOUBLE") - expectedSchema = new StructType().add("c1", DoubleType, true, defaultMetadata) + expectedSchema = new StructType() + .add("c1", DoubleType, true, defaultMetadata(DoubleType)) t = spark.table(tableName) assert(t.schema === expectedSchema) } @@ -444,7 +452,8 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql(s"ALTER TABLE $tableName ALTER COLUMN C1 DROP NOT NULL") - expectedSchema = new StructType().add("c1", DoubleType, true, defaultMetadata) + expectedSchema = new StructType() + .add("c1", DoubleType, true, defaultMetadata(IntegerType)) t = spark.table(tableName) assert(t.schema === expectedSchema) } @@ -506,10 +515,24 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql(s"ALTER TABLE $tableName ALTER COLUMN deptno TYPE VARCHAR(30)") val t = spark.table(tableName) val expected = new StructType() - .add("ID", CharType(10), true, defaultMetadata) - .add("deptno", VarcharType(30), true, defaultMetadata) + .add("ID", CharType(10), true, defaultMetadata(CharType(10))) + .add("deptno", VarcharType(30), true, defaultMetadata(VarcharType(30))) val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(expected) assert(t.schema === replaced) } } + + test("SPARK-45449: Cache Invalidation Issue with JDBC Table") { + withTable("h2.test.cache_t") { + withConnection { conn => + conn.prepareStatement( + """CREATE TABLE "test"."cache_t" (id decimal(25) PRIMARY KEY NOT NULL, + |name TEXT(32) NOT NULL)""".stripMargin).executeUpdate() + } + sql("INSERT OVERWRITE h2.test.cache_t SELECT 1 AS id, 'a' AS name") + sql("CACHE TABLE t1 SELECT id, name FROM h2.test.cache_t") + val plan = sql("select * from t1").queryExecution.sparkPlan + assert(plan.isInstanceOf[InMemoryTableScanExec]) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 4f78833abdb9f..a4a3d76db313a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestData} import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} -class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { +class OuterJoinSuite extends SparkPlanTest with SharedSparkSession with SQLTestData { + setupTestData() private val EnsureRequirements = new EnsureRequirements() @@ -325,4 +326,21 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { (null, null, 7, 7.0) ) ) + + testWithWholeStageCodegenOnAndOff( + "SPARK-46037: ShuffledHashJoin build left with left outer join, codegen off") { _ => + def join(hint: String): DataFrame = { + sql( + s""" + |SELECT /*+ $hint */ * + |FROM 
testData t1 + |LEFT OUTER JOIN + |testData2 t2 + |ON key = a AND concat(value, b) = '12' + |""".stripMargin) + } + val df1 = join("SHUFFLE_HASH(t1)") + val df2 = join("SHUFFLE_MERGE(t1)") + checkAnswer(df1, identity, df2.collect().toSeq) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 6347757e178c0..5cdbdc27b3259 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -960,6 +960,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", initValue = -1).isZero()) assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", initValue = 5).isZero()) } + + test("SQLMetric#toInfoUpdate") { + assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").toInfoUpdate.update === Some(-1)) + assert(SQLMetrics.createMetric(sparkContext, name = "m").toInfoUpdate.update === Some(0)) + } } case class CustomFileCommitProtocol( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 81667d52e16ae..e964867cb86ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -311,8 +311,8 @@ object InputOutputMetricsHelper { res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead var maxOutputRows = 0L - for (accum <- taskEnd.taskMetrics.externalAccums) { - val info = accum.toInfo(Some(accum.value), None) + taskEnd.taskMetrics.withExternalAccums(_.foreach { accum => + val info = accum.toInfoUpdate if (info.name.toString.contains("number of output rows")) { info.update match { case Some(n: Number) => @@ -322,7 +322,7 @@ object InputOutputMetricsHelper { case _ => // Ignore. 
} } - } + }) res.sumMaxOutputRows += maxOutputRows } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index d86faec1a7bbd..9a168dc80a03a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest} -import org.apache.spark.sql.functions.count +import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.functions.{col, count} import org.apache.spark.sql.test.SharedSparkSession class PythonUDFSuite extends QueryTest with SharedSparkSession { @@ -111,4 +111,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName) assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName))) } + + test("SPARK-48666: Python UDF execution against partitioned column") { + assume(shouldTestPythonUDFs) + withTable("t") { + spark.range(1).selectExpr("id AS t", "(id + 1) AS p").write.partitionBy("p").saveAsTable("t") + val table = spark.table("t") + val newTable = table.withColumn("new_column", pythonTestUDF(table("p"))) + val df = newTable.as("t1").join( + newTable.as("t2"), col("t1.new_column") === col("t2.new_column")) + checkAnswer(df, Row(0, 1, 1, 0, 1, 1)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 980d532dd4779..08f245135f589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -33,6 +33,18 @@ class HDFSMetadataLogSuite extends SharedSparkSession { private implicit def toOption[A](a: A): Option[A] = Option(a) + test("SPARK-46339: Directory with number name should not be treated as metadata log") { + withTempDir { temp => + val dir = new File(temp, "dir") + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) + assert(metadataLog.metadataPath.toString.endsWith("/dir")) + + // Create a directory with batch id 0 + new File(dir, "0").mkdir() + assert(metadataLog.getLatest() === None) + } + } + test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index e31b05c362f6a..973c1e0cb3b0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -19,21 +19,45 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.nio.charset.Charset +import java.util.concurrent.Executors +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ import scala.language.implicitConversions import org.apache.commons.io.FileUtils import 
org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalactic.source.Position import org.scalatest.Tag import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.CreateAtomicTestManager +import org.apache.spark.sql.execution.streaming.{CreateAtomicTestManager, FileSystemBasedCheckpointFileManager} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.{CancellableFSDataOutputStream, RenameBasedFSDataOutputStream} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.tags.SlowSQLTest import org.apache.spark.util.{ThreadUtils, Utils} +class NoOverwriteFileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + override def createAtomic(path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, + overwriteIfPossible: Boolean): Unit = { + if (!fs.exists(dstPath)) { + // only write if a file does not exist at this location + super.renameTempFile(srcPath, dstPath, overwriteIfPossible) + } + } +} + trait RocksDBStateStoreChangelogCheckpointingTestUtil { val rocksdbChangelogCheckpointingConfKey: String = RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" @@ -214,6 +238,35 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("SPARK-45419: Do not reuse SST files" + + " in different RocksDB instances") { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 0, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + for (version <- 0 to 2) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 3.zip + db.doMaintenance() + // Roll back to version 1 and start to process data. + for (version <- 1 to 3) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // Upload snapshot 4.zip, should not reuse the SST files in 3.zip + db.doMaintenance() + } + + withDB(remoteDir, conf = conf) { db => + // Open the db to verify that the state in 4.zip is no corrupted. + db.load(4) + } + } + // A rocksdb instance with changelog checkpointing enabled should be able to load // an existing checkpoint without changelog. testWithChangelogCheckpointingEnabled( @@ -254,6 +307,14 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared db.load(version, readOnly = true) assert(db.iterator().map(toStr).toSet === Set((version.toString, version.toString))) } + + // recommit 60 to ensure that acquireLock is released for maintenance + for (version <- 60 to 60) { + db.load(version - 1) + db.put(version.toString, version.toString) + db.remove((version - 1).toString) + db.commit() + } // Check that snapshots and changelogs get purged correctly. 
db.doMaintenance() assert(snapshotVersionsPresent(remoteDir) === Seq(30, 60)) @@ -419,6 +480,41 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("RocksDBFileManager: " + + "background snapshot upload doesn't acquire RocksDB instance lock") { + // Create a custom ExecutionContext + implicit val ec: ExecutionContext = ExecutionContext + .fromExecutor(Executors.newSingleThreadExecutor()) + + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(lockAcquireTimeoutMs = 10000, minDeltasForSnapshot = 0) + new File(remoteDir).delete() // to make sure that the directory gets created + + withDB(remoteDir, conf = conf) { db => + db.load(0) + db.put("0", "0") + db.commit() + + // Acquire lock + db.load(1) + db.put("1", "1") + + // Run doMaintenance in another thread + val maintenanceFuture = Future { + db.doMaintenance() + } + + val timeout = 5.seconds + + // Ensure that maintenance task runs without being blocked by task thread + ThreadUtils.awaitResult(maintenanceFuture, timeout) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + // Release lock + db.commit() + } + } + testWithChangelogCheckpointingEnabled("RocksDBFileManager: read and write changelog") { val dfsRootDir = new File(Utils.createTempDir().getAbsolutePath + "/state/1/1") val fileManager = new RocksDBFileManager( @@ -637,19 +733,19 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared // Save SAME version again with different checkpoint files and load back again to verify // whether files were overwritten. val cpFiles1_ = Seq( - "sst-file1.sst" -> 10, // same SST file as before, but same version, so should get copied + "sst-file1.sst" -> 10, // same SST file as before, this should get reused "sst-file2.sst" -> 25, // new SST file with same name as before, but different length "sst-file3.sst" -> 30, // new SST file "other-file1" -> 100, // same non-SST file as before, should not get copied "other-file2" -> 210, // new non-SST file with same name as before, but different length "other-file3" -> 300, // new non-SST file - "archive/00001.log" -> 1000, // same log file as before and version, so should get copied + "archive/00001.log" -> 1000, // same log file as before, this should get reused "archive/00002.log" -> 2500, // new log file with same name as before, but different length "archive/00003.log" -> 3000 // new log file ) saveCheckpointFiles(fileManager, cpFiles1_, version = 1, numKeys = 1001) - assert(numRemoteSSTFiles === 5, "shouldn't copy same files again") // 2 old + 3 new SST files - assert(numRemoteLogFiles === 5, "shouldn't copy same files again") // 2 old + 3 new log files + assert(numRemoteSSTFiles === 4, "shouldn't copy same files again") // 2 old + 2 new SST files + assert(numRemoteLogFiles === 4, "shouldn't copy same files again") // 2 old + 2 new log files loadAndVerifyCheckpointFiles(fileManager, verificationDir, version = 1, cpFiles1_, 1001) // Save another version and verify @@ -659,8 +755,8 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared "archive/00004.log" -> 4000 ) saveCheckpointFiles(fileManager, cpFiles2, version = 2, numKeys = 1501) - assert(numRemoteSSTFiles === 6) // 1 new file over earlier 5 files - assert(numRemoteLogFiles === 6) // 1 new file over earlier 5 files + assert(numRemoteSSTFiles === 5) // 1 new file over earlier 4 files + assert(numRemoteLogFiles === 5) // 1 new file over earlier 4 files loadAndVerifyCheckpointFiles(fileManager, 
verificationDir, version = 2, cpFiles2, 1501) // Loading an older version should work @@ -1123,6 +1219,437 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + test("time travel - validate successful RocksDB load") { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 1, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + for (version <- 0 to 1) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 2.zip + db.doMaintenance() + for (version <- Seq(2)) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 3.zip + db.doMaintenance() + // simulate db in another executor that override the zip file + withDB(remoteDir, conf = conf) { db1 => + for (version <- 0 to 1) { + db1.load(version) + db1.put(version.toString, version.toString) + db1.commit() + } + db1.doMaintenance() + } + db.load(2) + for (version <- Seq(2)) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 3.zip + db.doMaintenance() + // rollback to version 2 + db.load(2) + } + } + + test("time travel 2 - validate successful RocksDB load") { + Seq(1, 2).map(minDeltasForSnapshot => { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = minDeltasForSnapshot, + compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + for (version <- 0 to 1) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 2.zip + db.doMaintenance() + for (version <- 2 to 3) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + db.load(0) + // simulate db in another executor that override the zip file + withDB(remoteDir, conf = conf) { db1 => + for (version <- 0 to 1) { + db1.load(version) + db1.put(version.toString, version.toString) + db1.commit() + } + db1.doMaintenance() + } + for (version <- 2 to 3) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 4.zip + db.doMaintenance() + } + withDB(remoteDir, version = 4, conf = conf) { db => + } + }) + } + + test("time travel 3 - validate successful RocksDB load") { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 0, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + for (version <- 0 to 2) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 2.zip + db.doMaintenance() + for (version <- 1 to 3) { + db.load(version) + db.put(version.toString, version.toString) + db.commit() + } + // upload snapshot 4.zip + db.doMaintenance() + } + + withDB(remoteDir, version = 4, conf = conf) { db => + } + } + + testWithChangelogCheckpointingEnabled("time travel 4 -" + + " validate successful RocksDB load when metadata file is overwritten") { + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 2, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + for (version <- 0 to 1) { + db.load(version) + db.put(version.toString, 
version.toString) + db.commit() + } + + // load previous version, and recreate the snapshot + db.load(1) + db.put("3", "3") + + // upload any latest snapshots so far + db.doMaintenance() + db.commit() + // upload newly created snapshot 2.zip + db.doMaintenance() + } + + // reload version 2 - should succeed + withDB(remoteDir, version = 2, conf = conf) { db => + } + } + + testWithChangelogCheckpointingEnabled("time travel 5 -" + + "validate successful RocksDB load when metadata file is not overwritten") { + // Ensure commit doesn't modify the latestSnapshot that doMaintenance will upload + val fmClass = "org.apache.spark.sql.execution.streaming.state." + + "NoOverwriteFileSystemBasedCheckpointFileManager" + withTempDir { dir => + val conf = dbConf.copy(minDeltasForSnapshot = 0) // create snapshot every commit + val hadoopConf = new Configuration() + hadoopConf.set(STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, fmClass) + + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = conf, hadoopConf = hadoopConf) { db => + db.load(0) + db.put("a", "1") + db.commit() + + // load previous version, and recreate the snapshot + db.load(0) + db.put("a", "1") + + // upload version 1 snapshot created above + db.doMaintenance() + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.commit() // create snapshot again + + // load version 1 - should succeed + withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => + } + + // upload recently created snapshot + db.doMaintenance() + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + // load version 1 again - should succeed + withDB(remoteDir, version = 1, conf = conf, hadoopConf = hadoopConf) { db => + } + } + } + } + + test("validate Rocks DB SST files do not have a VersionIdMismatch" + + " when metadata file is not overwritten - scenario 1") { + val fmClass = "org.apache.spark.sql.execution.streaming.state." 
+ + "NoOverwriteFileSystemBasedCheckpointFileManager" + withTempDir { dir => + val dbConf = RocksDBConf(StateStoreConf(new SQLConf())) + val hadoopConf = new Configuration() + hadoopConf.set(STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, fmClass) + + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db1 => + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db2 => + // commit version 1 via db1 + db1.load(0) + db1.put("a", "1") + db1.put("b", "1") + + db1.commit() + + // commit version 1 via db2 + db2.load(0) + db2.put("a", "1") + db2.put("b", "1") + + db2.commit() + + // commit version 2 via db2 + db2.load(1) + db2.put("a", "2") + db2.put("b", "2") + + db2.commit() + + // reload version 1, this should succeed + db2.load(1) + db1.load(1) + + // reload version 2, this should succeed + db2.load(2) + db1.load(2) + } + } + } + } + + test("validate Rocks DB SST files do not have a VersionIdMismatch" + + " when metadata file is overwritten - scenario 1") { + withTempDir { dir => + val dbConf = RocksDBConf(StateStoreConf(new SQLConf())) + val hadoopConf = new Configuration() + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db1 => + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db2 => + // commit version 1 via db1 + db1.load(0) + db1.put("a", "1") + db1.put("b", "1") + + db1.commit() + + // commit version 1 via db2 + db2.load(0) + db2.put("a", "1") + db2.put("b", "1") + + db2.commit() + + // commit version 2 via db2 + db2.load(1) + db2.put("a", "2") + db2.put("b", "2") + + db2.commit() + + // reload version 1, this should succeed + db2.load(1) + db1.load(1) + + // reload version 2, this should succeed + db2.load(2) + db1.load(2) + } + } + } + } + + test("validate Rocks DB SST files do not have a VersionIdMismatch" + + " when metadata file is not overwritten - scenario 2") { + val fmClass = "org.apache.spark.sql.execution.streaming.state." 
+ + "NoOverwriteFileSystemBasedCheckpointFileManager" + withTempDir { dir => + val dbConf = RocksDBConf(StateStoreConf(new SQLConf())) + val hadoopConf = new Configuration() + hadoopConf.set(STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, fmClass) + + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db1 => + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db2 => + // commit version 1 via db2 + db2.load(0) + db2.put("a", "1") + db2.put("b", "1") + + db2.commit() + + // commit version 1 via db1 + db1.load(0) + db1.put("a", "1") + db1.put("b", "1") + + db1.commit() + + // commit version 2 via db2 + db2.load(1) + db2.put("a", "2") + db2.put("b", "2") + + db2.commit() + + // reload version 1, this should succeed + db2.load(1) + db1.load(1) + + // reload version 2, this should succeed + db2.load(2) + db1.load(2) + } + } + } + } + + test("validate Rocks DB SST files do not have a VersionIdMismatch" + + " when metadata file is overwritten - scenario 2") { + withTempDir { dir => + val dbConf = RocksDBConf(StateStoreConf(new SQLConf())) + val hadoopConf = new Configuration() + val remoteDir = dir.getCanonicalPath + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db1 => + withDB(remoteDir, conf = dbConf, hadoopConf = hadoopConf) { db2 => + // commit version 1 via db2 + db2.load(0) + db2.put("a", "1") + db2.put("b", "1") + + db2.commit() + + // commit version 1 via db1 + db1.load(0) + db1.put("a", "1") + db1.put("b", "1") + + db1.commit() + + // commit version 2 via db2 + db2.load(1) + db2.put("a", "2") + db2.put("b", "2") + + db2.commit() + + // reload version 1, this should succeed + db2.load(1) + db1.load(1) + + // reload version 2, this should succeed + db2.load(2) + db1.load(2) + } + } + } + } + + test("ensure local files deleted on filesystem" + + " are cleaned from dfs file mapping") { + def getSSTFiles(dir: File): Set[File] = { + val sstFiles = new mutable.HashSet[File]() + dir.listFiles().foreach { f => + if (f.isDirectory) { + sstFiles ++= getSSTFiles(f) + } else { + if (f.getName.endsWith(".sst")) { + sstFiles.add(f) + } + } + } + sstFiles.toSet + } + + def filterAndDeleteSSTFiles(dir: File, filesToKeep: Set[File]): Unit = { + dir.listFiles().foreach { f => + if (f.isDirectory) { + filterAndDeleteSSTFiles(f, filesToKeep) + } else { + if (!filesToKeep.contains(f) && f.getName.endsWith(".sst")) { + logInfo(s"deleting ${f.getAbsolutePath} from local directory") + f.delete() + } + } + } + } + + withTempDir { dir => + withTempDir { localDir => + val sqlConf = new SQLConf() + val dbConf = RocksDBConf(StateStoreConf(sqlConf)) + logInfo(s"config set to ${dbConf.compactOnCommit}") + val hadoopConf = new Configuration() + val remoteDir = dir.getCanonicalPath + withDB(remoteDir = remoteDir, + conf = dbConf, + hadoopConf = hadoopConf, + localDir = localDir) { db => + db.load(0) + db.put("a", "1") + db.put("b", "1") + db.commit() + db.doMaintenance() + + // find all SST files written in version 1 + val sstFiles = getSSTFiles(localDir) + + // make more commits, this would generate more SST files and write + // them to remoteDir + for (version <- 1 to 10) { + db.load(version) + db.put("c", "1") + db.put("d", "1") + db.commit() + db.doMaintenance() + } + + // clean the SST files committed after version 1 from local + // filesystem. 
This is similar to what a process like compaction + // where multiple L0 SST files can be merged into a single L1 file + filterAndDeleteSSTFiles(localDir, sstFiles) + + // reload 2, and overwrite commit for version 3, this should not + // reuse any locally deleted files as they should be removed from the mapping + db.load(2) + db.put("e", "1") + db.put("f", "1") + db.commit() + db.doMaintenance() + + // clean local state + db.load(0) + + // reload version 3, should be successful + db.load(3) + } + } + } + } + private def sqlConf = SQLConf.get.clone() private def dbConf = RocksDBConf(StateStoreConf(sqlConf)) @@ -1131,12 +1658,16 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared remoteDir: String, version: Int = 0, conf: RocksDBConf = dbConf, - hadoopConf: Configuration = new Configuration())( + hadoopConf: Configuration = new Configuration(), + localDir: File = Utils.createTempDir())( func: RocksDB => T): T = { var db: RocksDB = null try { db = new RocksDB( - remoteDir, conf = conf, hadoopConf = hadoopConf, + remoteDir, + conf = conf, + localRootDir = localDir, + hadoopConf = hadoopConf, loggingId = s"[Thread-${Thread.currentThread.getId}]") db.load(version) func(db) @@ -1161,7 +1692,11 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared numKeys: Int): Unit = { val checkpointDir = Utils.createTempDir().getAbsolutePath // local dir to create checkpoints generateFiles(checkpointDir, fileToLengths) - fileManager.saveCheckpointToDfs(checkpointDir, version, numKeys) + fileManager.saveCheckpointToDfs( + checkpointDir, + version, + numKeys, + fileManager.captureFileMapReference()) } def loadAndVerifyCheckpointFiles( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 02aa12b325ff7..512a095250ae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -272,6 +272,44 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("SPARK-48105: state store unload/close happens during the maintenance") { + tryWithProviderResource( + newStoreProvider(opId = Random.nextInt(), partition = 0, minDeltasForSnapshot = 1)) { + provider => + val store = provider.getStore(0).asInstanceOf[provider.HDFSBackedStateStore] + val values = (1 to 20) + val keys = values.map(i => ("a" + i)) + keys.zip(values).map{case (k, v) => put(store, k, 0, v)} + // commit state store with 20 keys. + store.commit() + // get the state store iterator: mimic the case which the iterator is hold in the + // maintenance thread. + val storeIterator = store.iterator() + + // the store iterator should still be valid as the maintenance thread may have already + // hold it and is doing snapshotting even though the state store is unloaded. + val outputKeys = new mutable.ArrayBuffer[String] + val outputValues = new mutable.ArrayBuffer[Int] + var cnt = 0 + while (storeIterator.hasNext) { + if (cnt == 10) { + // Mimic the case where the provider is loaded in another executor in the middle of + // iteration. When this happens, the provider will be unloaded and closed in + // current executor. 
+ provider.close() + } + val unsafeRowPair = storeIterator.next() + val (key, _) = keyRowToData(unsafeRowPair.key) + outputKeys.append(key) + outputValues.append(valueRowToData(unsafeRowPair.value)) + + cnt = cnt + 1 + } + assert(keys.sorted === outputKeys.sorted) + assert(values.sorted === outputValues.sorted) + } + } + test("maintenance") { val conf = new SparkConf() .setMaster("local") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala new file mode 100644 index 0000000000000..88237cd09ac71 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.ui + +import org.apache.spark.SparkFunSuite + +class SparkPlanGraphSuite extends SparkFunSuite { + test("SPARK-47503: name of a node should be escaped even if there is no metrics") { + val planGraphNode = new SparkPlanGraphNode( + id = 24, + name = "Scan JDBCRelation(\"test-schema\".tickets) [numPartitions=1]", + desc = "Scan JDBCRelation(\"test-schema\".tickets) [numPartitions=1] " + + "[ticket_no#0] PushedFilters: [], ReadSchema: struct", + metrics = List( + SQLPlanMetric( + name = "number of output rows", + accumulatorId = 75, + metricType = "sum" + ), + SQLPlanMetric( + name = "JDBC query execution time", + accumulatorId = 35, + metricType = "nsTiming"))) + val dotNode = planGraphNode.makeDotNode(Map.empty[Long, String]) + val expectedDotNode = " 24 [labelType=\"html\" label=\"
    " + + "Scan JDBCRelation(\\\"test-schema\\\".tickets) [numPartitions=1]

    \"];" + + assertResult(expectedDotNode)(dotNode) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 516be9a4e5958..a40a416bbb5a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -473,6 +473,186 @@ class ColumnVectorSuite extends SparkFunSuite { assert(testVector.getDoubles(0, 3)(2) == 1342.17729d) } + def check(expected: Seq[Any], testVector: WritableColumnVector): Unit = { + expected.zipWithIndex.foreach { + case (v: Integer, idx) => + assert(testVector.getInt(idx) == v) + assert(testVector.getInts(0, testVector.capacity)(idx) == v) + case (v: Short, idx) => + assert(testVector.getShort(idx) == v) + assert(testVector.getShorts(0, testVector.capacity)(idx) == v) + case (v: Byte, idx) => + assert(testVector.getByte(idx) == v) + assert(testVector.getBytes(0, testVector.capacity)(idx) == v) + case (v: Long, idx) => + assert(testVector.getLong(idx) == v) + assert(testVector.getLongs(0, testVector.capacity)(idx) == v) + case (v: Float, idx) => + assert(testVector.getFloat(idx) == v) + assert(testVector.getFloats(0, testVector.capacity)(idx) == v) + case (v: Double, idx) => + assert(testVector.getDouble(idx) == v) + assert(testVector.getDoubles(0, testVector.capacity)(idx) == v) + case (null, idx) => testVector.isNullAt(idx) + case (_, idx) => assert(false, s"Unexpected value at $idx") + } + + // Verify ColumnarArray.copy() works as expected + val arr = new ColumnarArray(testVector, 0, testVector.capacity) + assert(arr.toSeq(testVector.dataType) == expected) + assert(arr.copy().toSeq(testVector.dataType) == expected) + + if (expected.nonEmpty) { + val withOffset = new ColumnarArray(testVector, 1, testVector.capacity - 1) + assert(withOffset.toSeq(testVector.dataType) == expected.tail) + assert(withOffset.copy().toSeq(testVector.dataType) == expected.tail) + } + } + + testVectors("getInts with dictionary and nulls", 3, IntegerType) { testVector => + // Validate without dictionary + val expected = Seq(1, null, 3) + expected.foreach { + case i: Integer => testVector.appendInt(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(7, null, 9) + val dictArray = (Seq(-1, -1) ++ expectedDictionary.map { + case i: Integer => i.toInt + case _ => -1 + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, testVector) + } + + testVectors("getShorts with dictionary and nulls", 3, ShortType) { testVector => + // Validate without dictionary + val expected = Seq(1.toShort, null, 3.toShort) + expected.foreach { + case i: Short => testVector.appendShort(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(7.toShort, null, 9.toShort) + val dictArray = (Seq(-1, -1) ++ expectedDictionary.map { + case i: Short => i.toInt + case _ => -1 + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + 
testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, testVector) + } + + testVectors("getBytes with dictionary and nulls", 3, ByteType) { testVector => + // Validate without dictionary + val expected = Seq(1.toByte, null, 3.toByte) + expected.foreach { + case i: Byte => testVector.appendByte(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(7.toByte, null, 9.toByte) + val dictArray = (Seq(-1, -1) ++ expectedDictionary.map { + case i: Byte => i.toInt + case _ => -1 + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, testVector) + } + + testVectors("getLongs with dictionary and nulls", 3, LongType) { testVector => + // Validate without dictionary + val expected = Seq(2147483L, null, 2147485L) + expected.foreach { + case i: Long => testVector.appendLong(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(2147483648L, null, 2147483650L) + val dictArray = (Seq(-1L, -1L) ++ expectedDictionary.map { + case i: Long => i + case _ => -1L + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, testVector) + } + + testVectors("getFloats with dictionary and nulls", 3, FloatType) { testVector => + // Validate without dictionary + val expected = Seq(1.1f, null, 3.3f) + expected.foreach { + case i: Float => testVector.appendFloat(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(0.1f, null, 0.3f) + val dictArray = (Seq(-1f, -1f) ++ expectedDictionary.map { + case i: Float => i + case _ => -1f + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, testVector) + } + + testVectors("getDoubles with dictionary and nulls", 3, DoubleType) { testVector => + // Validate without dictionary + val expected = Seq(1.1d, null, 3.3d) + expected.foreach { + case i: Double => testVector.appendDouble(i) + case _ => testVector.appendNull() + } + check(expected, testVector) + + // Validate with dictionary + val expectedDictionary = Seq(1342.17727d, null, 1342.17729d) + val dictArray = (Seq(-1d, -1d) ++ expectedDictionary.map { + case i: Double => i + case _ => -1d + }).toArray + val dict = new ColumnDictionary(dictArray) + testVector.setDictionary(dict) + testVector.reserveDictionaryIds(3) + testVector.getDictionaryIds.putInt(0, 2) + testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry should be ignored + testVector.getDictionaryIds.putInt(2, 4) + check(expectedDictionary, 
testVector) + } + test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { withVector(new OffHeapColumnVector(8, arrayType)) { testVector => val data = testVector.arrayData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 4dd93983e87e3..a02137a56aacc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.expressions -import scala.collection.parallel.immutable.ParVector - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +24,7 @@ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.tags.SlowSQLTest -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} @SlowSQLTest class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -197,8 +195,11 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // The encrypt expression includes a random initialization vector to its encrypted result classOf[AesEncrypt].getName) - val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector) - parFuncs.foreach { funcId => + ThreadUtils.parmap( + spark.sessionState.functionRegistry.listFunction(), + prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples", + maxThreads = Runtime.getRuntime.availableProcessors + ) { funcId => // Examples can change settings. We clone the session to prevent tests clashing. val clonedSpark = spark.cloneSession() // Coalescing partitions can change result order, so disable it. 
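Aside (not part of the patch): the ExpressionInfoSuite hunk above replaces the `ParVector` parallel collection with `ThreadUtils.parmap`, i.e. mapping a function over the registered functions on a bounded, named thread pool. The standalone sketch below only illustrates that general pattern with plain Scala futures; it is not Spark's actual `ThreadUtils` implementation, and the pool prefix, thread count, and `ParMapSketch` object are illustrative assumptions.

import java.util.concurrent.{Executors, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParMapSketch {
  // Map `f` over `in` using at most `maxThreads` daemon threads whose names share `prefix`.
  def parmap[I, O](in: Seq[I], prefix: String, maxThreads: Int)(f: I => O): Seq[O] = {
    val counter = new AtomicInteger(0)
    val pool = Executors.newFixedThreadPool(maxThreads, new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, s"$prefix-${counter.incrementAndGet()}")
        t.setDaemon(true)
        t
      }
    })
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      // Submit one future per element and wait for all of them, preserving input order.
      val futures = in.map(i => Future(f(i)))
      Await.result(Future.sequence(futures), Duration.Inf)
    } finally {
      pool.shutdown()
    }
  }

  def main(args: Array[String]): Unit = {
    // Example usage: fan out a cheap computation over 4 named threads, mirroring how the
    // suite fans out the per-function example checks.
    val squares = parmap(1 to 10, prefix = "example-check", maxThreads = 4)(n => n * n)
    println(squares.mkString(", "))
  }
}

Compared with the previous `ParVector` approach, an explicit pool like this bounds concurrency and gives the worker threads a recognizable name prefix, which is the apparent motivation for the change in the hunk above.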
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index c6bf220e45d52..470bd15cc418c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -635,7 +635,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf val description = "this is a test table" withTable("t") { - withTempDir { dir => + withTempDir { baseDir => + val dir = new File(baseDir, "test%prefix") spark.catalog.createTable( tableName = "t", source = "json", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 93b6652d516cc..f4702ee9edb3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -29,7 +29,7 @@ import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ import org.apache.spark.{SparkException, SparkSQLException} -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Observation, QueryTest, Row} import org.apache.spark.sql.catalyst.{analysis, TableIdentifier} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.ShowCreateTable @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCo import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.functions.{lit, percentile_approx} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSparkSession @@ -76,7 +77,10 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } } - val defaultMetadata = new MetadataBuilder().putLong("scale", 0).build() + def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType]) + .build() override def beforeAll(): Unit = { super.beforeAll() @@ -906,13 +910,18 @@ class JDBCSuite extends QueryTest with SharedSparkSession { test("MySQLDialect catalyst type mapping") { val mySqlDialect = JdbcDialects.get("jdbc:mysql") - val metadata = new MetadataBuilder() + val metadata = new MetadataBuilder().putBoolean("isSigned", value = true) assert(mySqlDialect.getCatalystType(java.sql.Types.VARBINARY, "BIT", 2, metadata) == Some(LongType)) assert(metadata.build().contains("binarylong")) assert(mySqlDialect.getCatalystType(java.sql.Types.VARBINARY, "BIT", 1, metadata) == None) assert(mySqlDialect.getCatalystType(java.sql.Types.BIT, "TINYINT", 1, metadata) == Some(BooleanType)) + assert(mySqlDialect.getCatalystType(java.sql.Types.TINYINT, "TINYINT", 1, metadata) == + Some(ByteType)) + metadata.putBoolean("isSigned", value = false) + assert(mySqlDialect.getCatalystType(java.sql.Types.TINYINT, "TINYINT", 1, metadata) === + Some(ShortType)) } test("SPARK-35446: MySQLDialect type mapping of float") { @@ -1052,10 +1061,9 @@ class JDBCSuite extends QueryTest with SharedSparkSession { val h2 = JdbcDialects.get(url) val derby = 
JdbcDialects.get("jdbc:derby:db") val table = "weblogs" - val defaultQuery = s"SELECT * FROM $table WHERE 1=0" - val limitQuery = s"SELECT 1 FROM $table LIMIT 1" - assert(MySQL.getTableExistsQuery(table) == limitQuery) - assert(Postgres.getTableExistsQuery(table) == limitQuery) + val defaultQuery = s"SELECT 1 FROM $table WHERE 1=0" + assert(MySQL.getTableExistsQuery(table) == defaultQuery) + assert(Postgres.getTableExistsQuery(table) == defaultQuery) assert(db2.getTableExistsQuery(table) == defaultQuery) assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) @@ -1271,7 +1279,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { test("SPARK 12941: The data type mapping for StringType to Oracle") { val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") assert(oracleDialect.getJDBCType(StringType). - map(_.databaseTypeDefinition).get == "CLOB") + map(_.databaseTypeDefinition).get == "VARCHAR2(255)") } test("SPARK-16625: General data types to be mapped to Oracle") { @@ -1289,7 +1297,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { assert(getJdbcType(oracleDialect, DoubleType) == "NUMBER(19, 4)") assert(getJdbcType(oracleDialect, ByteType) == "NUMBER(3)") assert(getJdbcType(oracleDialect, ShortType) == "NUMBER(5)") - assert(getJdbcType(oracleDialect, StringType) == "CLOB") + assert(getJdbcType(oracleDialect, StringType) == "VARCHAR2(255)") assert(getJdbcType(oracleDialect, BinaryType) == "BLOB") assert(getJdbcType(oracleDialect, DateType) == "DATE") assert(getJdbcType(oracleDialect, TimestampType) == "TIMESTAMP") @@ -1361,8 +1369,8 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } test("SPARK-16848: jdbc API throws an exception for user specified schema") { - val schema = StructType(Seq(StructField("name", StringType, false, defaultMetadata), - StructField("theid", IntegerType, false, defaultMetadata))) + val schema = StructType(Seq(StructField("name", StringType, false, defaultMetadata(StringType)), + StructField("theid", IntegerType, false, defaultMetadata(IntegerType)))) val parts = Array[String]("THEID < 2", "THEID >= 2") val e1 = intercept[AnalysisException] { spark.read.schema(schema).jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties()) @@ -1382,8 +1390,9 @@ class JDBCSuite extends QueryTest with SharedSparkSession { props.put("customSchema", customSchema) val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props) assert(df.schema.size === 2) - val expectedSchema = new StructType(CatalystSqlParser.parseTableSchema(customSchema).map( - f => StructField(f.name, f.dataType, f.nullable, defaultMetadata)).toArray) + val structType = CatalystSqlParser.parseTableSchema(customSchema) + val expectedSchema = new StructType(structType.map( + f => StructField(f.name, f.dataType, f.nullable, defaultMetadata(f.dataType))).toArray) assert(df.schema === CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema)) assert(df.count() === 3) } @@ -1401,7 +1410,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { val df = sql("select * from people_view") assert(df.schema.length === 2) val expectedSchema = new StructType(CatalystSqlParser.parseTableSchema(customSchema) - .map(f => StructField(f.name, f.dataType, f.nullable, defaultMetadata)).toArray) + .map(f => StructField(f.name, f.dataType, f.nullable, defaultMetadata(f.dataType))).toArray) assert(df.schema === CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema)) assert(df.count() 
=== 3) @@ -1548,8 +1557,9 @@ class JDBCSuite extends QueryTest with SharedSparkSession { } test("jdbc data source shouldn't have unnecessary metadata in its schema") { - var schema = StructType(Seq(StructField("NAME", VarcharType(32), true, defaultMetadata), - StructField("THEID", IntegerType, true, defaultMetadata))) + var schema = StructType( + Seq(StructField("NAME", VarcharType(32), true, defaultMetadata(VarcharType(32))), + StructField("THEID", IntegerType, true, defaultMetadata(IntegerType)))) schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) val df = spark.read.format("jdbc") .option("Url", urlWithUserAndPass) @@ -2057,4 +2067,18 @@ class JDBCSuite extends QueryTest with SharedSparkSession { val df = sql("SELECT * FROM composite_name WHERE `last name` = 'smith'") assert(df.collect.toSet === Set(Row("smith", 1))) } + + test("SPARK-45475: saving a table via JDBC should work with observe API") { + val tableName = "test_table" + val namedObservation = Observation("named") + val observed_df = spark.range(100).observe( + namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val")) + + observed_df.write.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", tableName).save() + + val expected = Map("percentile_approx_val" -> 49) + assert(namedObservation.get === expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index ae0cfe17b11f5..5d2108d2b8fce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -185,6 +185,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel conn.prepareStatement("INSERT INTO \"test\".\"datetime\" VALUES " + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"address\" (email TEXT(32) NOT NULL)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_def@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc%def@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc%_def@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_%def@gmail.com')").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"address\" VALUES " + + "('abc_''%def@gmail.com')").executeUpdate() + conn.prepareStatement("CREATE TABLE \"test\".\"binary1\" (name TEXT(32),b BINARY(20))") .executeUpdate() val stmt = conn.prepareStatement("INSERT INTO \"test\".\"binary1\" VALUES (?, ?)") @@ -1096,7 +1109,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a")) checkFiltersRemoved(df3) - checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']") + checkPushedInfo(df3, raw"PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%' ESCAPE '\']") checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false))) val df4 = spark.table("h2.test.employee").filter($"is_manager") @@ -1240,6 +1253,94 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df17, Seq(Row(6, "jen", 12000, 1200, true))) } + test("SPARK-38432: escape the single quote, _ and % for DS 
V2 pushdown") { + val df1 = spark.table("h2.test.address").filter($"email".startsWith("abc_")) + checkFiltersRemoved(df1) + checkPushedInfo(df1, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_%' ESCAPE '\']") + checkAnswer(df1, + Seq(Row("abc_%def@gmail.com"), Row("abc_'%def@gmail.com"), Row("abc_def@gmail.com"))) + + val df2 = spark.table("h2.test.address").filter($"email".startsWith("abc%")) + checkFiltersRemoved(df2) + checkPushedInfo(df2, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\%%' ESCAPE '\']") + checkAnswer(df2, Seq(Row("abc%_def@gmail.com"), Row("abc%def@gmail.com"))) + + val df3 = spark.table("h2.test.address").filter($"email".startsWith("abc%_")) + checkFiltersRemoved(df3) + checkPushedInfo(df3, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\%\_%' ESCAPE '\']") + checkAnswer(df3, Seq(Row("abc%_def@gmail.com"))) + + val df4 = spark.table("h2.test.address").filter($"email".startsWith("abc_%")) + checkFiltersRemoved(df4) + checkPushedInfo(df4, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_\%%' ESCAPE '\']") + checkAnswer(df4, Seq(Row("abc_%def@gmail.com"))) + + val df5 = spark.table("h2.test.address").filter($"email".startsWith("abc_'%")) + checkFiltersRemoved(df5) + checkPushedInfo(df5, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 'abc\_''\%%' ESCAPE '\']") + checkAnswer(df5, Seq(Row("abc_'%def@gmail.com"))) + + val df6 = spark.table("h2.test.address").filter($"email".endsWith("_def@gmail.com")) + checkFiltersRemoved(df6) + checkPushedInfo(df6, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_def@gmail.com' ESCAPE '\']") + checkAnswer(df6, Seq(Row("abc%_def@gmail.com"), Row("abc_def@gmail.com"))) + + val df7 = spark.table("h2.test.address").filter($"email".endsWith("%def@gmail.com")) + checkFiltersRemoved(df7) + checkPushedInfo(df7, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%def@gmail.com' ESCAPE '\']") + checkAnswer(df7, + Seq(Row("abc%def@gmail.com"), Row("abc_%def@gmail.com"), Row("abc_'%def@gmail.com"))) + + val df8 = spark.table("h2.test.address").filter($"email".endsWith("%_def@gmail.com")) + checkFiltersRemoved(df8) + checkPushedInfo(df8, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\%\_def@gmail.com' ESCAPE '\']") + checkAnswer(df8, Seq(Row("abc%_def@gmail.com"))) + + val df9 = spark.table("h2.test.address").filter($"email".endsWith("_%def@gmail.com")) + checkFiltersRemoved(df9) + checkPushedInfo(df9, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_\%def@gmail.com' ESCAPE '\']") + checkAnswer(df9, Seq(Row("abc_%def@gmail.com"))) + + val df10 = spark.table("h2.test.address").filter($"email".endsWith("_'%def@gmail.com")) + checkFiltersRemoved(df10) + checkPushedInfo(df10, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%\_''\%def@gmail.com' ESCAPE '\']") + checkAnswer(df10, Seq(Row("abc_'%def@gmail.com"))) + + val df11 = spark.table("h2.test.address").filter($"email".contains("c_d")) + checkFiltersRemoved(df11) + checkPushedInfo(df11, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_d%' ESCAPE '\']") + checkAnswer(df11, Seq(Row("abc_def@gmail.com"))) + + val df12 = spark.table("h2.test.address").filter($"email".contains("c%d")) + checkFiltersRemoved(df12) + checkPushedInfo(df12, raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\%d%' ESCAPE '\']") + checkAnswer(df12, Seq(Row("abc%def@gmail.com"))) + + val df13 = spark.table("h2.test.address").filter($"email".contains("c%_d")) + checkFiltersRemoved(df13) + checkPushedInfo(df13, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE 
'%c\%\_d%' ESCAPE '\']") + checkAnswer(df13, Seq(Row("abc%_def@gmail.com"))) + + val df14 = spark.table("h2.test.address").filter($"email".contains("c_%d")) + checkFiltersRemoved(df14) + checkPushedInfo(df14, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_\%d%' ESCAPE '\']") + checkAnswer(df14, Seq(Row("abc_%def@gmail.com"))) + + val df15 = spark.table("h2.test.address").filter($"email".contains("c_'%d")) + checkFiltersRemoved(df15) + checkPushedInfo(df15, + raw"PushedFilters: [EMAIL IS NOT NULL, EMAIL LIKE '%c\_''\%d%' ESCAPE '\']") + checkAnswer(df15, Seq(Row("abc_'%def@gmail.com"))) + } + test("scan with filter push-down with ansi mode") { Seq(false, true).foreach { ansiMode => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { @@ -1325,10 +1426,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df6, ansiMode) val expectedPlanFragment6 = if (ansiMode) { "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, " + - "CAST(BONUS AS string) LIKE '%30%', CAST(DEPT AS byte) > 1, " + + raw"CAST(BONUS AS string) LIKE '%30%' ESCAPE '\', CAST(DEPT AS byte) > 1, " + "CAST(DEPT AS short) > 1, CAST(BONUS AS decimal(20,2)) > 1200.00]" } else { - "PushedFilters: [BONUS IS NOT NULL, DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%']" + "PushedFilters: [BONUS IS NOT NULL, " + + raw"DEPT IS NOT NULL, CAST(BONUS AS string) LIKE '%30%' ESCAPE '\']" } checkPushedInfo(df6, expectedPlanFragment6) checkAnswer(df6, Seq(Row(2, "david", 10000, 1300, true))) @@ -1538,8 +1640,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), - Seq(Row("test", "people", false), Row("test", "empty_table", false), - Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false), + Seq(Row("test", "address", false), Row("test", "people", false), + Row("test", "empty_table", false), Row("test", "employee", false), + Row("test", "item", false), Row("test", "dept", false), Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false), Row("test", "datetime", false), Row("test", "binary1", false))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 75f440caefc34..1954cce7fdc2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -650,6 +650,19 @@ abstract class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-48991: Move path initialization into try-catch block") { + val logAppender = new LogAppender("Assume no metadata directory.") + Seq(null, "", "file:tmp").foreach { path => + withLogAppender(logAppender) { + assert(!FileStreamSink.hasMetadata(Seq(path), spark.sessionState.newHadoopConf(), conf)) + } + + assert(logAppender.loggingEvents.map(_.getMessage.getFormattedMessage).contains( + "Assume no metadata directory. 
Error while looking for metadata directory in the path:" + + s" $path.")) + } + } } object PendingCommitFilesTrackingManifestFileCommitProtocol { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala index 20fb17fe6ec2a..b3a4b8e3d3cbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -366,91 +366,101 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { ) } - test("applyInPandasWithState - streaming w/ event time timeout + watermark") { - assume(shouldTestPandasUDFs) + Seq(true, false).map { ifUseDateTimeType => + test("applyInPandasWithState - streaming w/ event time timeout + watermark " + + s"ifUseDateTimeType=$ifUseDateTimeType") { + assume(shouldTestPandasUDFs) - // timestamp_seconds assumes the base timezone is UTC. However, the provided function - // localizes it. Therefore, this test assumes the timezone is in UTC - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { - val pythonScript = - """ - |import calendar - |import os - |import datetime - |import pandas as pd - |from pyspark.sql.types import StructType, StringType, StructField, IntegerType - | - |tpe = StructType([ - | StructField("key", StringType()), - | StructField("maxEventTimeSec", IntegerType())]) - | - |def func(key, pdf_iter, state): - | assert state.getCurrentProcessingTimeMs() >= 0 - | assert state.getCurrentWatermarkMs() >= -1 - | - | timeout_delay_sec = 5 - | if state.hasTimedOut: - | state.remove() - | yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) - | else: - | m = state.getOption - | if m is None: - | max_event_time_sec = 0 - | else: - | max_event_time_sec = m[0] - | - | for pdf in pdf_iter: - | pser = pdf.eventTime.apply( - | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) - | max_event_time_sec = int(max(pser.max(), max_event_time_sec)) - | - | state.update((max_event_time_sec,)) - | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec - | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) - | yield pd.DataFrame({'key': [key[0]], - | 'maxEventTimeSec': [max_event_time_sec]}) - |""".stripMargin - val pythonFunc = TestGroupedMapPandasUDFWithState( - name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val timeoutMs = if (ifUseDateTimeType) { + "datetime.datetime.fromtimestamp(timeout_timestamp_sec)" + } else { + "timeout_timestamp_sec * 1000" + } - val inputData = MemoryStream[(String, Int)] - val inputDataDF = - inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) - val outputStructType = StructType( - Seq( - StructField("key", StringType), - StructField("maxEventTimeSec", IntegerType))) - val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) - val result = - inputDataDF - .withWatermark("eventTime", "10 seconds") - .groupBy("key") - .applyInPandasWithState( - pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], - outputStructType, - stateStructType, - "Update", - "EventTimeTimeout") + val pythonScript = + s""" + |import calendar + |import os + |import datetime + |import pandas as pd + |from pyspark.sql.types import StructType, StringType, StructField, IntegerType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("maxEventTimeSec", IntegerType())]) + | + |def func(key, pdf_iter, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | assert state.getCurrentWatermarkMs() >= -1 + | + | timeout_delay_sec = 5 + | if state.hasTimedOut: + | state.remove() + | yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | else: + | m = state.getOption + | if m is None: + | max_event_time_sec = 0 + | else: + | max_event_time_sec = m[0] + | + | for pdf in pdf_iter: + | pser = pdf.eventTime.apply( + | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + | max_event_time_sec = int(max(pser.max(), max_event_time_sec)) + | + | state.update((max_event_time_sec,)) + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.setTimeoutTimestamp($timeoutMs) + | yield pd.DataFrame({'key': [key[0]], + | 'maxEventTimeSec': [max_event_time_sec]}) + |""".stripMargin.format("") + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) - testStream(result, Update)( - StartStream(), + val inputData = MemoryStream[(String, Int)] + val inputDataDF = + inputData.toDF().select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputDataDF + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "EventTimeTimeout") - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), - // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. - CheckNewAnswer(("a", 15)), // Output = max event time of a + testStream(result, Update)( + StartStream(), - AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckNewAnswer(), // No output as data should get filtered by watermark + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. 
+ CheckNewAnswer(("a", 15)), // Output = max event time of a - AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" - CheckNewAnswer(("a", 15)), // Max event time is still the same - // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. - // Watermark is still 5 as max event time for all data is still 15. + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark - AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" - // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. - CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 - ) + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is + // 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index fb5445ae436a1..0149e95586499 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -878,6 +878,60 @@ class MultiStatefulOperatorsSuite testOutputWatermarkInJoin(join3, input1, -40L * 1000 - 1) } + test("SPARK-49829 time window agg per each source followed by stream-stream join") { + val inputStream1 = MemoryStream[Long] + val inputStream2 = MemoryStream[Long] + + val df1 = inputStream1.toDF() + .selectExpr("value", "timestamp_seconds(value) AS ts") + .withWatermark("ts", "5 seconds") + + val df2 = inputStream2.toDF() + .selectExpr("value", "timestamp_seconds(value) AS ts") + .withWatermark("ts", "5 seconds") + + val df1Window = df1.groupBy( + window($"ts", "10 seconds") + ).agg(sum("value").as("sum_df1")) + + val df2Window = df2.groupBy( + window($"ts", "10 seconds") + ).agg(sum("value").as("sum_df2")) + + val joined = df1Window.join(df2Window, "window", "inner") + .selectExpr("CAST(window.end AS long) AS window_end", "sum_df1", "sum_df2") + + // The test verifies the case where both sides produce input as time window (append mode) + // for stream-stream join having join condition for equality of time window. + // Inputs are produced into stream-stream join when the time windows are completed, meaning + // they will be evicted in this batch for stream-stream join as well. (NOTE: join condition + // does not delay the state watermark in stream-stream join). + // Before SPARK-49829, left side does not add the input to state store if it's going to evict + // in this batch, which breaks the match between input from left side and input from right + // side for this batch. 
+ testStream(joined)( + MultiAddData( + (inputStream1, Seq(1L, 2L, 3L, 4L, 5L)), + (inputStream2, Seq(5L, 6L, 7L, 8L, 9L)) + ), + // watermark: 5 - 5 = 0 + CheckNewAnswer(), + MultiAddData( + (inputStream1, Seq(11L, 12L, 13L, 14L, 15L)), + (inputStream2, Seq(15L, 16L, 17L, 18L, 19L)) + ), + // watermark: 15 - 5 = 10 (windows for [0, 10) are completed) + // Before SPARK-49829, the test fails because this row is not produced. + CheckNewAnswer((10L, 15L, 35L)), + MultiAddData( + (inputStream1, Seq(100L)), + (inputStream2, Seq(101L)) + ), + // watermark: 100 - 5 = 95 (windows for [0, 20) are completed) + CheckNewAnswer((20L, 65L, 85L)) + ) + } + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 57ced748cd9f0..07837f5c06473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.StreamExecution trait StateStoreMetricsTest extends StreamTest { @@ -24,6 +26,8 @@ trait StateStoreMetricsTest extends StreamTest { private var lastCheckedRecentProgressIndex = -1 private var lastQuery: StreamExecution = null + override val streamingTimeout = 120.seconds + override def beforeEach(): Unit = { super.beforeEach() lastCheckedRecentProgressIndex = -1 @@ -106,7 +110,7 @@ trait StateStoreMetricsTest extends StreamTest { AssertOnQuery(s"Check operator progress metrics: operatorName = $operatorName, " + s"numShufflePartitions = $numShufflePartitions, " + s"numStateStoreInstances = $numStateStoreInstances") { q => - eventually(timeout(streamingTimeout)) { + eventually(timeout(streamingTimeout), interval(200.milliseconds)) { val (progressesSinceLastCheck, lastCheckedProgressIndex, numStateOperators) = retrieveProgressesSinceLastCheck(q) assert(operatorIndex < numStateOperators, s"Invalid operator Index: $operatorIndex") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index cb7995abcd092..69e404a473834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -809,6 +809,13 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with case (key, None) => sparkSession.conf.unset(key) } sparkSession.streams.removeListener(listener) + // The state store is stopped here to unload all state stores and terminate all maintenance + // threads. It is necessary because the temp directory used by the checkpoint directory + // may be deleted soon after, and the maintenance thread may see unexpected error and + // cause unexpected behavior. Doing it after a test finishes might be too late because + // sometimes the checkpoint directory is under `withTempDir`, and in this case the temp + // directory is deleted before the test finishes. 
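A schematic of the race the comment above describes, assuming the usual withTempDir/testStream helpers (illustrative only, not part of the patch):

withTempDir { dir =>                          // the temp dir owns the checkpoint location
  testStream(df)(
    StartStream(checkpointLocation = dir.getCanonicalPath),
    AddData(inputData, 1),
    CheckAnswer(1))
}                                             // dir is deleted as soon as this block exits
// Without the StateStore.stop() below, background maintenance threads could still be
// snapshotting or cleaning state files under `dir` at this point and fail unexpectedly.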
+ StateStore.stop() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala index 595fc1cb9cea8..9a02ab3df7dd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationWithinWatermarkSuite.scala @@ -199,4 +199,25 @@ class StreamingDeduplicationWithinWatermarkSuite extends StateStoreMetricsTest { ) } } + + test("SPARK-46676: canonicalization of StreamingDeduplicateWithinWatermarkExec should work") { + withTempDir { checkpoint => + val dedupeInputData = MemoryStream[(String, Int)] + val dedupe = dedupeInputData.toDS() + .withColumn("eventTime", timestamp_seconds($"_2")) + .withWatermark("eventTime", "10 second") + .dropDuplicatesWithinWatermark("_1") + .select($"_1", $"eventTime".cast("long").as[Long]) + + testStream(dedupe, Append)( + StartStream(checkpointLocation = checkpoint.getCanonicalPath), + AddData(dedupeInputData, "a" -> 1), + CheckNewAnswer("a" -> 1), + Execute { q => + // This threw out error before SPARK-46676. + q.lastExecution.executedPlan.canonicalized + } + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 3e1bc57dfa245..aad91601758ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1417,6 +1417,56 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { ) } } + + test("SPARK-49829 left-outer join, input being unmatched is between WM for late event and " + + "WM for eviction") { + + withTempDir { checkpoint => + // This config needs to be set, otherwise no-data batch will be triggered and after + // no-data batch, WM for late event and WM for eviction would be same. + withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") { + val memoryStream1 = MemoryStream[(String, Int)] + val memoryStream2 = MemoryStream[(String, Int)] + + val data1 = memoryStream1.toDF() + .selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime") + .withWatermark("eventTime", "0 seconds") + val data2 = memoryStream2.toDF() + .selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime") + .withWatermark("eventTime", "0 seconds") + + val joinedDf = data1.join(data2, Seq("key", "eventTime"), "leftOuter") + .selectExpr("key", "CAST(eventTime AS long) AS eventTime") + + testStream(joinedDf)( + StartStream(checkpointLocation = checkpoint.getCanonicalPath), + // batch 0 + // WM: late record = 0, eviction = 0 + MultiAddData( + (memoryStream1, Seq(("a", 1), ("b", 2))), + (memoryStream2, Seq(("b", 2), ("c", 1))) + ), + CheckNewAnswer(("b", 2)), + // state rows + // left: ("a", 1), ("b", 2) + // right: ("b", 2), ("c", 1) + // batch 1 + // WM: late record = 0, eviction = 2 + // Before Spark introduces multiple stateful operator, WM for late record was same as + // WM for eviction, hence ("d", 1) was treated as late record. + // With the multiple state operator, ("d", 1) is added in batch 1 but also evicted in + // batch 1. Note that the eviction is happening with state watermark: for this join, + // state watermark = state eviction under join condition. 
Before SPARK-49829, this + // wasn't producing unmatched row, and it is fixed. + AddData(memoryStream1, ("d", 1)), + CheckNewAnswer(("a", 1), ("d", 1)) + // state rows + // left: none + // right: none + ) + } + } + } } @SlowSQLTest @@ -1824,4 +1874,119 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { assertNumStateRows(total = 9, updated = 4) ) } + + test("SPARK-49829 two chained stream-stream left outer joins among three input streams") { + withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") { + val memoryStream1 = MemoryStream[(Long, Int)] + val memoryStream2 = MemoryStream[(Long, Int)] + val memoryStream3 = MemoryStream[(Long, Int)] + + val data1 = memoryStream1.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v1") + .withWatermark("eventTime", "0 seconds") + val data2 = memoryStream2.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v2") + .withWatermark("eventTime", "0 seconds") + val data3 = memoryStream3.toDF() + .selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v3") + .withWatermark("eventTime", "0 seconds") + + val join = data1 + .join(data2, Seq("eventTime"), "leftOuter") + .join(data3, Seq("eventTime"), "leftOuter") + .selectExpr("CAST(eventTime AS long) AS eventTime", "v1", "v2", "v3") + + testStream(join)( + // batch 0 + // WM: late event = 0, eviction = 0 + MultiAddData( + (memoryStream1, Seq((20L, 1))), + (memoryStream2, Seq((20L, 1))), + (memoryStream3, Seq((20L, 1))) + ), + CheckNewAnswer((20, 1, 1, 1)), + // state rows + // 1st join + // left: (20, 1) + // right: (20, 1) + // 2nd join + // left: (20, 1, 1) + // right: (20, 1) + // batch 1 + // WM: late event = 0, eviction = 20 + MultiAddData( + (memoryStream1, Seq((21L, 2))), + (memoryStream2, Seq((21L, 2))) + ), + CheckNewAnswer(), + // state rows + // 1st join + // left: (21, 2) + // right: (21, 2) + // 2nd join + // left: (21, 2, 2) + // right: none + // batch 2 + // WM: late event = 20, eviction = 20 (slowest: inputStream3) + MultiAddData( + (memoryStream1, Seq((22L, 3))), + (memoryStream3, Seq((22L, 3))) + ), + CheckNewAnswer(), + // state rows + // 1st join + // left: (21, 2), (22, 3) + // right: (21, 2) + // 2nd join + // left: (21, 2, 2) + // right: (22, 3) + // batch 3 + // WM: late event = 20, eviction = 21 (slowest: inputStream2) + AddData(memoryStream1, (23L, 4)), + CheckNewAnswer(Row(21, 2, 2, null)), + // state rows + // 1st join + // left: (22, 3), (23, 4) + // right: none + // 2nd join + // left: none + // right: (22, 3) + // batch 4 + // WM: late event = 21, eviction = 21 (slowest: inputStream2) + MultiAddData( + (memoryStream1, Seq((24L, 5))), + (memoryStream2, Seq((24L, 5))), + (memoryStream3, Seq((24L, 5))) + ), + CheckNewAnswer(Row(24, 5, 5, 5)), + // state rows + // 1st join + // left: (22, 3), (23, 4), (24, 5) + // right: (24, 5) + // 2nd join + // left: (24, 5, 5) + // right: (22, 3), (24, 5) + // batch 5 + // WM: late event = 21, eviction = 24 + // just trigger a new batch with arbitrary data as the original test relies on no-data + // batch, and we need to check with remaining unmatched outputs + AddData(memoryStream1, (100L, 6)), + // Before SPARK-49829, the test fails because (23, 4, null, null) wasn't produced. + // (The assertion of state for left inputs & right inputs weren't included on the test + // before SPARK-49829.) 
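The batch-by-batch comments in these SPARK-49829 tests rely on two distinct watermark values. A simplified sketch of how they relate, derived from the commented values above rather than from the implementation itself:

// With no-data batches disabled and a "0 seconds" delay:
//   evictionWatermark(n)  = min over inputs of the max event time seen through batch n-1
//   lateEventWatermark(n) = evictionWatermark(n - 1)
// A row can therefore be accepted (above lateEventWatermark) and still fall at or below
// evictionWatermark in the very same batch; SPARK-49829 ensures such a row still takes
// part in the outer join, producing the unmatched output checked below, before eviction.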
+ CheckNewAnswer(Row(22, 3, null, 3), Row(23, 4, null, null)) + ) + + /* + // The collection of the above new answers is the same with below in original test: + val expected = Array( + Row(Timestamp.valueOf("2024-02-10 10:20:00"), 1, 1, 1), + Row(Timestamp.valueOf("2024-02-10 10:21:00"), 2, 2, null), + Row(Timestamp.valueOf("2024-02-10 10:22:00"), 3, null, 3), + Row(Timestamp.valueOf("2024-02-10 10:23:00"), 4, null, null), + Row(Timestamp.valueOf("2024-02-10 10:24:00"), 5, 5, 5), + ) + */ + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala new file mode 100644 index 0000000000000..f651bfb7f3c72 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.{count, expr, lit, timestamp_seconds, window} +import org.apache.spark.sql.internal.SQLConf + +/** + * This test ensures that any optimizations done by Spark SQL optimizer are + * correct for Streaming queries. 
+ */ +class StreamingQueryOptimizationCorrectnessSuite extends StreamTest { + import testImplicits._ + + test("streaming Union with literal produces correct results") { + val inputStream1 = MemoryStream[Int] + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit("ds1")) + .withColumn("count", $"value") + .select("name", "count") + + val inputStream2 = MemoryStream[Int] + val ds2 = inputStream2 + .toDS() + .withColumn("name", lit("ds2")) + .withColumn("count", $"value") + .select("name", "count") + + val result = + ds1.union(ds2) + .groupBy("name") + .count() + + testStream(result, OutputMode.Complete())( + AddData(inputStream1, 1), + ProcessAllAvailable(), + AddData(inputStream2, 1), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)) + ) + } + + test("streaming aggregate with literal and watermark after literal column" + + " produces correct results on query change") { + withTempDir { dir => + val inputStream1 = MemoryStream[Timestamp] + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit("ds1")) + .withColumn("ts", $"value") + .withWatermark("ts", "1 minutes") + .select("name", "ts") + + val result = + ds1.groupBy("name").count() + + testStream(result, OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, Timestamp.valueOf("2023-01-02 00:00:00")), + ProcessAllAvailable() + ) + + val ds2 = inputStream1 + .toDS() + .withColumn("name", lit("ds2")) + .withColumn("ts", $"value") + .withWatermark("ts", "1 minutes") + .select("name", "ts") + + val result2 = + ds2.groupBy("name").count() + + testStream(result2, OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, Timestamp.valueOf("2023-01-03 00:00:00")), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)), + AddData(inputStream1, Timestamp.valueOf("2023-01-04 00:00:00")), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 2)) + ) + } + } + + test("streaming aggregate with literal and watermark before literal column" + + " produces correct results on query change") { + withTempDir { dir => + val inputStream1 = MemoryStream[Timestamp] + val ds1 = inputStream1 + .toDS() + .withColumn("ts", $"value") + .withWatermark("ts", "1 minutes") + .withColumn("name", lit("ds1")) + .select("name", "ts") + + val result = + ds1.groupBy("name").count() + + testStream(result, OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, Timestamp.valueOf("2023-01-02 00:00:00")), + ProcessAllAvailable() + ) + + val ds2 = inputStream1 + .toDS() + .withColumn("ts", $"value") + .withWatermark("ts", "1 minutes") + .withColumn("name", lit("ds2")) + .select("name", "ts") + + val result2 = + ds2.groupBy("name").count() + + testStream(result2, OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, Timestamp.valueOf("2023-01-03 00:00:00")), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)), + AddData(inputStream1, Timestamp.valueOf("2023-01-04 00:00:00")), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 2)) + ) + } + } + + test("streaming aggregate with literal" + + " produces correct results on query change") { + withTempDir { dir => + val inputStream1 = MemoryStream[Int] + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit("ds1")) + .withColumn("count", $"value") + .select("name", "count") + + val result = + ds1.groupBy("name").count() + + testStream(result, 
OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, 1), + ProcessAllAvailable() + ) + + val ds2 = inputStream1 + .toDS() + .withColumn("name", lit("ds2")) + .withColumn("count", $"value") + .select("name", "count") + + val result2 = + ds2.groupBy("name").count() + + testStream(result2, OutputMode.Complete())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, 1), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)) + ) + } + } + + test("stream stream join with literal" + + " produces correct results") { + withTempDir { dir => + import java.sql.Timestamp + val inputStream1 = MemoryStream[Int] + val inputStream2 = MemoryStream[Int] + + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-01-01 00:00:00"))) + .withWatermark("name", "1 minutes") + .withColumn("count1", lit(1)) + + val ds2 = inputStream2 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-01-02 00:00:00"))) + .withWatermark("name", "1 minutes") + .withColumn("count2", lit(2)) + + + val result = + ds1.join(ds2, "name", "full") + .select("name", "count1", "count2") + + testStream(result, OutputMode.Append())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, 1), + ProcessAllAvailable(), + AddData(inputStream2, 1), + ProcessAllAvailable(), + AddData(inputStream1, 2), + ProcessAllAvailable(), + AddData(inputStream2, 2), + ProcessAllAvailable(), + CheckNewAnswer() + ) + + // modify the query and update literal values for name + val ds3 = inputStream1 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-02-01 00:00:00"))) + .withWatermark("name", "1 minutes") + .withColumn("count1", lit(3)) + + val ds4 = inputStream2 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-02-02 00:00:00"))) + .withWatermark("name", "1 minutes") + .withColumn("count2", lit(4)) + + val result2 = + ds3.join(ds4, "name", "full") + .select("name", "count1", "count2") + + testStream(result2, OutputMode.Append())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(inputStream1, 1), + ProcessAllAvailable(), + AddData(inputStream2, 1), + ProcessAllAvailable(), + AddData(inputStream1, 2), + ProcessAllAvailable(), + AddData(inputStream2, 2), + ProcessAllAvailable(), + CheckNewAnswer( + Row(Timestamp.valueOf("2023-01-01 00:00:00"), + 1, null.asInstanceOf[java.lang.Integer]), + Row(Timestamp.valueOf("2023-01-01 00:00:00"), + 1, null.asInstanceOf[java.lang.Integer]), + Row(Timestamp.valueOf("2023-01-02 00:00:00"), + null.asInstanceOf[java.lang.Integer], 2), + Row(Timestamp.valueOf("2023-01-02 00:00:00"), + null.asInstanceOf[java.lang.Integer], 2) + ) + ) + } + } + + test("streaming SQL distinct usage with literal grouping" + + " key produces correct results") { + val inputStream1 = MemoryStream[Int] + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit("ds1")) + .withColumn("count", $"value") + .select("name", "count") + + val inputStream2 = MemoryStream[Int] + val ds2 = inputStream2 + .toDS() + .withColumn("name", lit("ds2")) + .withColumn("count", $"value") + .select("name", "count") + + val result = + ds1.union(ds2) + .groupBy("name") + .as[String, (String, Int, Int)] + .keys + + testStream(result, OutputMode.Complete())( + AddData(inputStream1, 1), + ProcessAllAvailable(), + AddData(inputStream2, 1), + ProcessAllAvailable(), + CheckNewAnswer(Row("ds1"), Row("ds2")) + ) + } + + test("streaming window aggregation with literal time column" + + " key produces 
correct results") { + val inputStream1 = MemoryStream[Int] + val ds1 = inputStream1 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-01-01 00:00:00"))) + .withColumn("count", $"value") + .select("name", "count") + + val inputStream2 = MemoryStream[Int] + val ds2 = inputStream2 + .toDS() + .withColumn("name", lit(Timestamp.valueOf("2023-01-02 00:00:00"))) + .withColumn("count", $"value") + .select("name", "count") + + val result = + ds1.union(ds2) + .groupBy( + window($"name", "1 second", "1 second") + ) + .count() + + testStream(result, OutputMode.Complete())( + AddData(inputStream1, 1), + ProcessAllAvailable(), + AddData(inputStream2, 1), + ProcessAllAvailable(), + CheckNewAnswer( + Row( + Row(Timestamp.valueOf("2023-01-01 00:00:00"), Timestamp.valueOf("2023-01-01 00:00:01")), + 1), + Row( + Row(Timestamp.valueOf("2023-01-02 00:00:00"), Timestamp.valueOf("2023-01-02 00:00:01")), + 1)) + ) + } + + test("stream stream join with literals produces correct value") { + withTempDir { dir => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1 + .toDF() + .withColumn("key", $"value") + .withColumn("leftValue", lit(1)) + .select("key", "leftValue") + + val df2 = input2 + .toDF() + .withColumn("key", $"value") + .withColumn("rightValue", lit(2)) + .select("key", "rightValue") + + val result = df1 + .join(df2, "key") + .select("key", "leftValue", "rightValue") + + testStream(result, OutputMode.Append())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(input1, 1), + ProcessAllAvailable(), + AddData(input2, 1), + ProcessAllAvailable(), + CheckAnswer(Row(1, 1, 2)) + ) + } + } + + test("stream stream join with literals produces correct value on query change") { + withTempDir { dir => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1 + .toDF() + .withColumn("key", lit("key1")) + .withColumn("leftValue", lit(1)) + .select("key", "leftValue") + + val df2 = input2 + .toDF() + .withColumn("key", lit("key2")) + .withColumn("rightValue", lit(2)) + .select("key", "rightValue") + + val result = df1 + .join(df2, "key") + .select("key", "leftValue", "rightValue") + + testStream(result, OutputMode.Append())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(input1, 1), + ProcessAllAvailable(), + AddData(input2, 1), + ProcessAllAvailable() + ) + + val df3 = input1 + .toDF() + .withColumn("key", lit("key2")) + .withColumn("leftValue", lit(3)) + .select("key", "leftValue") + + val df4 = input2 + .toDF() + .withColumn("key", lit("key1")) + .withColumn("rightValue", lit(4)) + .select("key", "rightValue") + + val result2 = df3 + .join(df4, "key") + .select("key", "leftValue", "rightValue") + + testStream(result2, OutputMode.Append())( + StartStream(checkpointLocation = dir.getAbsolutePath), + AddData(input1, 1), + ProcessAllAvailable(), + AddData(input2, 1), + ProcessAllAvailable(), + CheckAnswer( + Row("key1", 1, 4), + Row("key2", 3, 2)) + ) + } + } + + test("SPARK-48267: regression test, stream-stream union followed by stream-batch join") { + withTempDir { dir => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF().withColumn("code", lit(1)) + val df2 = input2.toDF().withColumn("code", lit(null)) + + // NOTE: The column 'ref_code' is known to be non-nullable. 
+ val batchDf = spark.range(1, 5).select($"id".as("ref_code")) + + val unionDf = df1.union(df2) + .join(batchDf, expr("code = ref_code")) + .select("value") + + testStream(unionDf)( + StartStream(checkpointLocation = dir.getAbsolutePath), + + AddData(input1, 1, 2, 3), + CheckNewAnswer(1, 2, 3), + + AddData(input2, 1, 2, 3), + // The test failed before SPARK-47305 - the test failed with below error message: + // org.apache.spark.sql.streaming.StreamingQueryException: Stream-stream join without + // equality predicate is not supported.; + // Join Inner + // :- StreamingDataSourceV2ScanRelation[value#3] MemoryStreamDataSource + // +- LocalRelation + // Note that LocalRelation is actually a batch source (Range) but due to + // a bug, it was incorrect marked to the streaming. SPARK-47305 fixed the bug. + CheckNewAnswer() + ) + } + } + + test("SPARK-48481: DISTINCT with empty stream source should retain AGGREGATE") { + def doTest(numExpectedStatefulOperatorsForOneEmptySource: Int): Unit = { + withTempView("tv1", "tv2") { + val inputStream1 = MemoryStream[Int] + val ds1 = inputStream1.toDS() + ds1.registerTempTable("tv1") + + val inputStream2 = MemoryStream[Int] + val ds2 = inputStream2.toDS() + ds2.registerTempTable("tv2") + + // DISTINCT is rewritten to AGGREGATE, hence an AGGREGATEs for each source + val unioned = spark.sql( + """ + | WITH u AS ( + | SELECT DISTINCT value AS value FROM tv1 + | ), v AS ( + | SELECT DISTINCT value AS value FROM tv2 + | ) + | SELECT value FROM u UNION ALL SELECT value FROM v + |""".stripMargin + ) + + testStream(unioned, OutputMode.Update())( + MultiAddData(inputStream1, 1, 1, 2)(inputStream2, 1, 1, 2), + CheckNewAnswer(1, 2, 1, 2), + Execute { qe => + val stateOperators = qe.lastProgress.stateOperators + // Aggregate should be "stateful" one + assert(stateOperators.length === 2) + stateOperators.zipWithIndex.foreach { case (op, id) => + assert(op.numRowsUpdated === 2, s"stateful OP ID: $id") + } + }, + AddData(inputStream2, 2, 2, 3), + // NOTE: this is probably far from expectation to have 2 as output given user intends + // deduplicate, but the behavior is still correct with rewritten node and output mode: + // Aggregate & Update mode. + // TODO: Probably we should disallow DISTINCT or rewrite to + // dropDuplicates(WithinWatermark) for streaming source? + CheckNewAnswer(2, 3), + Execute { qe => + val stateOperators = qe.lastProgress.stateOperators + // Aggregate should be "stateful" one + assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource) + val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0) + assert(opWithUpdatedRows.length === 1) + // If this were dropDuplicates, numRowsUpdated should have been 1. 
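The assertions that follow rest on the rewrite mentioned in the comment above: a streaming SELECT DISTINCT is planned as a grouping aggregate rather than a deduplication operator. A rough sketch of that equivalence (illustrative; the rule involved is Catalyst's ReplaceDistinctWithAggregate):

// SELECT DISTINCT value FROM tv1   is planned like   SELECT value FROM tv1 GROUP BY value
// e.g. both of these yield an Aggregate node over the streaming source:
//   spark.sql("SELECT DISTINCT value FROM tv1")
//   spark.sql("SELECT value FROM tv1 GROUP BY value")
// That is why each branch of the UNION is a stateful aggregation here, and why this batch
// reports numRowsUpdated == 2 (keys 2 and 3 both touch aggregation state) instead of the
// 1 a dropDuplicates-style operator would report for the single newly seen key.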
+ assert(opWithUpdatedRows.head._1.numRowsUpdated === 2, + s"stateful OP ID: ${opWithUpdatedRows.head._2}") + }, + AddData(inputStream1, 4, 4, 5), + CheckNewAnswer(4, 5), + Execute { qe => + val stateOperators = qe.lastProgress.stateOperators + assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource) + val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0) + assert(opWithUpdatedRows.length === 1) + assert(opWithUpdatedRows.head._1.numRowsUpdated === 2, + s"stateful OP ID: ${opWithUpdatedRows.head._2}") + } + ) + } + } + + doTest(numExpectedStatefulOperatorsForOneEmptySource = 2) + + withSQLConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED.key -> "true") { + doTest(numExpectedStatefulOperatorsForOneEmptySource = 1) + } + } + + test("SPARK-49699: observe node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .observe("observation", count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + val observeRow = qe.lastExecution.observedMetrics.get("observation") + assert(observeRow.get.getAs[Long]("rows") == 3L) + } + ) + } + + test("SPARK-49699: watermark node is not pruned out from PruneFilters") { + // NOTE: The test actually passes without SPARK-49699, because of the trickiness of + // filter pushdown and PruneFilters. Unlike observe node, the `false` filter is pushed down + // below to watermark node, hence PruneFilters rule does not prune out watermark node even + // before SPARK-49699. Propagate empty relation does not also propagate emptiness into + // watermark node, so the node is retained. The test is added for preventing regression. + + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 second") + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + // If the watermark node is pruned out, this would be null. + assert(qe.lastProgress.eventTime.get("watermark") != null) + } + ) + } + + test("SPARK-49699: stateful operator node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .groupBy("value") + .count() + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. 
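For context on why filter(expr("false")) is the trigger in these SPARK-49699 tests: PruneFilters normally replaces a provably false filter with an empty relation, discarding the whole subtree. A simplified sketch of the batch behaviour being guarded against (not the literal rule implementation):

// PruneFilters, roughly:
//   Filter(Literal(false), child)   ==>   LocalRelation(child.output, data = Nil)
// For a streaming child that subtree may contain observe, withWatermark, or stateful
// nodes, so pruning it would silently drop metrics, the event-time watermark, or state;
// the tests assert those nodes survive even though the filter can never emit any rows.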
+ .filter(expr("false")) + + testStream(df, OutputMode.Complete())( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + assert(qe.lastProgress.stateOperators.length == 1) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 4a6325eb06074..8565056cda6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.InputPartition @@ -1318,6 +1318,51 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-46062: streaming query reading from CTE, which refers to temp view from " + + "streaming source") { + val inputStream = MemoryStream[Int] + inputStream.toDF().createOrReplaceTempView("tv") + + val df = spark.sql( + """ + |WITH w as ( + | SELECT * FROM tv + |) + |SELECT value from w + |""".stripMargin) + + testStream(df)( + AddData(inputStream, 1, 2, 3), + CheckAnswer(1, 2, 3), + Execute { q => + var isStreamingForCteDef: Option[Boolean] = None + var isStreamingForCteRef: Option[Boolean] = None + + q.analyzedPlan.foreach { + case d: CTERelationDef => + assert(d.resolved, "The definition node must be resolved after analysis.") + isStreamingForCteDef = Some(d.isStreaming) + + case d: CTERelationRef => + assert(d.resolved, "The reference node must be marked as resolved after analysis.") + isStreamingForCteRef = Some(d.isStreaming) + + case _ => + } + + assert(isStreamingForCteDef.isDefined && isStreamingForCteRef.isDefined, + "Both definition and reference for CTE should be available in analyzed plan.") + + assert(isStreamingForCteDef.get, "Expected isStreaming=true for CTE definition, but " + + "isStreaming is set to false.") + + assert(isStreamingForCteDef === isStreamingForCteRef, + "isStreaming flag should be carried over from definition to reference, " + + s"definition: ${isStreamingForCteDef.get}, reference: ${isStreamingForCteRef.get}.") + } + ) + } + private def checkExceptionMessage(df: DataFrame): Unit = { withTempDir { outputDir => withTempDir { checkpointDir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 17348fe2dcbb5..b40f9210a686d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -1363,4 +1363,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } } + + test("SPARK-39910: read files from Hadoop archives") { + val fileSchema = new StructType().add("str", StringType) + val harPath = testFile("test-data/test-archive.har") + 
.replaceFirst("file:/", "har:/") + + testRead(spark.read.schema(fileSchema).csv(s"$harPath/test.csv"), data, fileSchema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index dd55fcfe42cac..e937173a590f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -423,8 +423,9 @@ private[sql] trait SQLTestUtilsBase * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - spark.sessionState.catalog.setCurrentDatabase(db) - Utils.tryWithSafeFinally(f)(spark.sessionState.catalog.setCurrentDatabase("default")) + spark.sessionState.catalogManager.setCurrentNamespace(Array(db)) + Utils.tryWithSafeFinally(f)( + spark.sessionState.catalogManager.setCurrentNamespace(Array("default"))) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala index 658f79fc28942..c63c748953f1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.status.api.v1.sql import java.net.URL import java.text.SimpleDateFormat +import javax.servlet.http.HttpServletResponse import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods @@ -148,4 +149,12 @@ class SqlResourceWithActualMetricsSuite } } + test("SPARK-45291: Use unknown query execution id instead of no such app when id is invalid") { + val url = new URL(spark.sparkContext.ui.get.webUrl + + s"/api/v1/applications/${spark.sparkContext.applicationId}/sql/${Long.MaxValue}") + val (code, resultOpt, error) = getContentAndCode(url) + assert(code === HttpServletResponse.SC_NOT_FOUND) + assert(resultOpt.isEmpty) + assert(error.get === s"unknown query execution id: ${Long.MaxValue}") + } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 7222d49ecb020..5d2708dfdd714 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java index 782e47a6cd902..4b8d2cb1536cd 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/CookieSigner.java @@ -81,8 +81,7 @@ public String verifyAndExtract(String signedStr) { LOG.debug("Signature generated for " + rawValue + " inside verify is " + currentSignature); } if (!MessageDigest.isEqual(originalSignature.getBytes(), currentSignature.getBytes())) { - throw new IllegalArgumentException("Invalid sign, original = " + originalSignature + - " current = " + currentSignature); + throw new IllegalArgumentException("Invalid sign"); } return rawValue; } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index 8d77b238ff41f..e3316cef241c3 100644 --- 
a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -17,7 +17,6 @@ package org.apache.hive.service.auth; import java.io.IOException; -import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; @@ -85,18 +84,9 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - private static Field keytabFile = null; private static Method getKeytab = null; static { Class clz = UserGroupInformation.class; - try { - keytabFile = clz.getDeclaredField("keytabFile"); - keytabFile.setAccessible(true); - } catch (NoSuchFieldException nfe) { - LOG.debug("Cannot find private field \"keytabFile\" in class: " + - UserGroupInformation.class.getCanonicalName(), nfe); - keytabFile = null; - } try { getKeytab = clz.getDeclaredMethod("getKeytab"); @@ -347,9 +337,7 @@ public static boolean needUgiLogin(UserGroupInformation ugi, String principal, S private static String getKeytabFromUgi() { synchronized (UserGroupInformation.class) { try { - if (keytabFile != null) { - return (String) keytabFile.get(null); - } else if (getKeytab != null) { + if (getKeytab != null) { return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); } else { return null; diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/RowSetUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/RowSetUtils.scala index 9625021f392cb..43130bb204f6d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/RowSetUtils.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/RowSetUtils.scala @@ -52,11 +52,7 @@ object RowSetUtils { rows: Seq[Row], schema: Array[DataType], timeFormatters: TimeFormatters): TRowSet = { - var i = 0 - val rowSize = rows.length - val tRows = new java.util.ArrayList[TRow](rowSize) - while (i < rowSize) { - val row = rows(i) + val tRows = rows.map { row => val tRow = new TRow() var j = 0 val columnSize = row.length @@ -65,9 +61,8 @@ object RowSetUtils { tRow.addToColVals(columnValue) j += 1 } - i += 1 - tRows.add(tRow) - } + tRow + }.asJava new TRowSet(startRowOffSet, tRows) } @@ -136,8 +131,7 @@ object RowSetUtils { var i = 0 val rowSize = rows.length val values = new java.util.ArrayList[String](rowSize) - while (i < rowSize) { - val row = rows(i) + rows.foreach { row => nulls.set(i, row.isNullAt(ordinal)) val value = if (row.isNullAt(ordinal)) { "" @@ -159,8 +153,7 @@ object RowSetUtils { val size = rows.length val ret = new java.util.ArrayList[T](size) var idx = 0 - while (idx < size) { - val row = rows(idx) + rows.foreach { row => if (row.isNullAt(ordinal)) { nulls.set(idx, true) ret.add(idx, defaultVal) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index a9b46739fa665..47ec242c9da95 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import 
java.security.PrivilegedExceptionAction import java.util.{Collections, Map => JMap} -import java.util.concurrent.{Executors, RejectedExecutionException, TimeUnit} +import java.util.concurrent.{Executors, RejectedExecutionException, ScheduledExecutorService, TimeUnit} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -60,6 +60,7 @@ private[hive] class SparkExecuteStatementOperation( queryTimeout } } + private var timeoutExecutor: ScheduledExecutorService = _ private val forceCancel = sqlContext.conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL) @@ -114,7 +115,7 @@ private[hive] class SparkExecuteStatementOperation( val offset = iter.getPosition val rows = iter.take(maxRows).toList log.debug(s"Returning result set with ${rows.length} rows from offsets " + - s"[${iter.getFetchStart}, ${offset}) with $statementId") + s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId") RowSetUtils.toTRowSet(offset, rows, dataTypes, getProtocolVersion, getTimeFormatters) } @@ -132,7 +133,7 @@ private[hive] class SparkExecuteStatementOperation( setHasResultSet(true) // avoid no resultset for async run if (timeout > 0) { - val timeoutExecutor = Executors.newSingleThreadScheduledExecutor() + timeoutExecutor = Executors.newSingleThreadScheduledExecutor() timeoutExecutor.schedule(new Runnable { override def run(): Unit = { try { @@ -306,6 +307,11 @@ private[hive] class SparkExecuteStatementOperation( if (statementId != null) { sqlContext.sparkContext.cancelJobGroup(statementId) } + // Shutdown the timeout thread if any, while cleaning up this operation + if (timeoutExecutor != null && + getStatus.getState != OperationState.TIMEDOUT && getStatus.getState.isTerminal) { + timeoutExecutor.shutdownNow() + } } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index af11b817d65b0..cb85993e5e099 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.session.SessionState import org.apache.logging.log4j.LogManager import org.apache.logging.log4j.core.Logger @@ -61,15 +63,18 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { spark.sql("show tables") }) for (msg <- Seq( - "show tables", "Could not connect to meta store", "org.apache.thrift.transport.TTransportException", "Connection refused")) { - exceptionString.contains(msg) + assert(exceptionString.contains(msg)) } } finally { Thread.currentThread().setContextClassLoader(originalClassLoader) spark.sparkContext.setLogLevel(originalLevel.toString) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + SessionState.detachSession() + Hive.closeCurrent() spark.stop() } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d3a9a9f08411c..38dcd1d8b00af 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -19,8 +19,6 @@ package 
org.apache.spark.sql.hive.thriftserver import java.io._ import java.nio.charset.StandardCharsets -import java.sql.Timestamp -import java.util.Date import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ @@ -145,11 +143,8 @@ class CliSuite extends SparkFunSuite { val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we - // add a timestamp to provide more diagnosis information. - val newLine = s"${new Timestamp(new Date().getTime)} - $source> $line" - log.info(newLine) - buffer += newLine + logInfo(s"$source> $line") + buffer += line if (line.startsWith("Spark master: ") && line.contains("Application Id: ")) { foundMasterAndApplicationIdMessage.trySuccess(()) @@ -198,7 +193,7 @@ class CliSuite extends SparkFunSuite { ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeoutForQuery) log.info("Found all expected output.") } catch { case cause: Throwable => - val message = + val message = lock.synchronized { s""" |======================= |CliSuite failure output @@ -212,6 +207,7 @@ class CliSuite extends SparkFunSuite { |End CliSuite failure output |=========================== """.stripMargin + } logError(message, cause) fail(message, cause) } finally { @@ -388,7 +384,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-11188 Analysis error reporting") { + testRetry("SPARK-11188 Analysis error reporting") { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" -> "nonexistent_table" @@ -556,7 +552,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SparkException with root cause will be printStacktrace") { + testRetry("SparkException with root cause will be printStacktrace") { // If it is not in silent mode, will print the stacktrace runCliWithin( 1.minute, @@ -580,8 +576,8 @@ class CliSuite extends SparkFunSuite { runCliWithin(1.minute)("SELECT MAKE_DATE(-44, 3, 15);" -> "-0044-03-15") } - test("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { - runCliWithin(4.minute)( + testRetry("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { + runCliWithin(1.minute)( "/* SELECT 'test';*/ SELECT 'test';" -> "test", ";;/* SELECT 'test';*/ SELECT 'test';" -> "test", "/* SELECT 'test';*/;; SELECT 'test';" -> "test", @@ -628,8 +624,8 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37555: spark-sql should pass last unclosed comment to backend") { - runCliWithin(5.minute)( + testRetry("SPARK-37555: spark-sql should pass last unclosed comment to backend") { + runCliWithin(1.minute)( // Only unclosed comment. "/* SELECT /*+ HINT() 4; */;".stripMargin -> "Syntax error at or near ';'", // Unclosed nested bracketed comment. 
@@ -642,7 +638,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { + testRetry("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { runCliWithin(2.minute, errorResponses = Seq("ParseException"))( "delete jar dummy.jar;" -> "Syntax error at or near 'jar': missing 'FROM'.(line 1, pos 7)") } @@ -683,7 +679,7 @@ class CliSuite extends SparkFunSuite { SparkSQLEnv.stop() } - test("SPARK-39068: support in-memory catalog and running concurrently") { + testRetry("SPARK-39068: support in-memory catalog and running concurrently") { val extraConf = Seq("-c", s"${StaticSQLConf.CATALOG_IMPLEMENTATION.key}=in-memory") val cd = new CountDownLatch(2) def t: Thread = new Thread { @@ -703,7 +699,7 @@ class CliSuite extends SparkFunSuite { } // scalastyle:off line.size.limit - test("formats of error messages") { + testRetry("formats of error messages") { def check(format: ErrorMessageFormat.Value, errorMessage: String, silent: Boolean): Unit = { val expected = errorMessage.split(System.lineSeparator()).map("" -> _) runCliWithin( @@ -815,7 +811,6 @@ class CliSuite extends SparkFunSuite { s"spark.sql.catalog.$catalogName.url=jdbc:derby:memory:$catalogName;create=true" val catalogDriver = s"spark.sql.catalog.$catalogName.driver=org.apache.derby.jdbc.AutoloadedDriver" - val database = s"-database $catalogName.SYS" val catalogConfigs = Seq(catalogImpl, catalogDriver, catalogUrl, "spark.sql.catalogImplementation=in-memory") .flatMap(Seq("--conf", _)) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala index f5167a4ea7377..62d97772bcbc1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/HiveThriftServer2ListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.thriftserver.ui +import java.io.File import java.util.Properties import org.mockito.Mockito.{mock, RETURNS_SMART_NULLS} @@ -34,6 +35,15 @@ class HiveThriftServer2ListenerSuite extends SparkFunSuite with BeforeAndAfter { private var kvstore: ElementTrackingStore = _ + protected override def beforeAll(): Unit = { + val tmpDirName = System.getProperty("java.io.tmpdir") + val tmpDir = new File(tmpDirName) + if (!tmpDir.exists()) { + tmpDir.mkdirs() + } + super.beforeAll() + } + after { if (kvstore != null) { kvstore.close() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPageSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPageSuite.scala index d7e1852199639..1245e6740ebbe 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPageSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPageSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.thriftserver.ui +import java.io.File import java.util.{Calendar, Locale} import javax.servlet.http.HttpServletRequest @@ -34,6 +35,15 @@ class ThriftServerPageSuite extends SparkFunSuite with BeforeAndAfter { private var kvstore: ElementTrackingStore = _ + protected override def beforeAll(): Unit = { + val tmpDirName = 
System.getProperty("java.io.tmpdir") + val tmpDir = new File(tmpDirName) + if (!tmpDir.exists()) { + tmpDir.mkdirs() + } + super.beforeAll() + } + after { if (kvstore != null) { kvstore.close() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd323dc4b24e1..0467603c01cd0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,6 +41,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone private val originalAnsiMode = TestHive.conf.getConf(SQLConf.ANSI_ENABLED) + private val originalStoreAssignmentPolicy = + TestHive.conf.getConf(SQLConf.STORE_ASSIGNMENT_POLICY) private val originalCreateHiveTable = TestHive.conf.getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT) @@ -76,6 +78,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) TestHive.setConf(SQLConf.ANSI_ENABLED, originalAnsiMode) + TestHive.setConf(SQLConf.STORE_ASSIGNMENT_POLICY, originalStoreAssignmentPolicy) TestHive.setConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT, originalCreateHiveTable) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index b04c7565f8a3b..9a313907eb130 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index dbfb4d65bd144..54b9db967d2dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.internal.SQLConf @@ -66,6 +67,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { val key = QualifiedTableName( // scalastyle:off caselocale + table.catalog.getOrElse(CatalogManager.SESSION_CATALOG_NAME).toLowerCase, table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) // scalastyle:on caselocale @@ -191,8 +193,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileType: String, isWrite: Boolean): LogicalRelation = { val metastoreSchema = relation.tableMeta.schema - val 
tableIdentifier = - QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) + val tableIdentifier = QualifiedTableName(relation.tableMeta.identifier.catalog.get, + relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2d0bcdff07151..08e02c90ebd63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -92,11 +92,11 @@ class HiveSessionStateBuilder( new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: + new DetermineTableStats(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = DetectAmbiguousSelfJoin +: - new DetermineTableStats(session) +: RelationConversions(catalog) +: QualifyLocationWithWarehouse(catalog) +: PreprocessTableCreation(catalog) +: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3da3d4a0eb5c8..c53a6c378d457 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -161,7 +161,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoStatement( r: HiveTableRelation, partSpec, _, query, overwrite, ifPartitionNotExists, _) - if DDLUtils.isHiveTable(r.tableMeta) => + if DDLUtils.isHiveTable(r.tableMeta) && query.resolved => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists, query.output.map(_.name)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 63f672b22bad2..60ff9ec42f29d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.metastore.api.{Database, EnvironmentContext, Function => HiveFunction, FunctionType, Index, MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.io.AcidUtils -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState @@ -1190,7 +1190,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { recordHiveCall() hive.getPartitionsByNames(table, partNames.asJava) } catch { - case ex: InvocationTargetException if ex.getCause.isInstanceOf[MetaException] => + case ex: HiveException if ex.getCause.isInstanceOf[MetaException] => logWarning("Caught Hive MetaException attempting to get partition metadata by " + "filter from client side. 
Falling back to fetching all partition metadata", ex) recordHiveCall() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index a28a0464e6ee9..18090b53e3c10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -139,6 +139,10 @@ private[hive] object IsolatedClientLoader extends Logging { SparkSubmitUtils.buildIvySettings( Some(remoteRepos), ivyPath), + Some(SparkSubmitUtils.buildIvySettings( + Some(remoteRepos), + ivyPath, + useLocalM2AsCache = false)), transitive = true, exclusions = version.exclusions) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 9304074e866ca..eb69f23d2876a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -101,7 +101,6 @@ package object client { "org.pentaho:pentaho-aggdesigner-algorithm")) // Since HIVE-23980, calcite-core included in Hive package jar. - // For spark, only VersionsSuite currently creates a hive materialized view for testing. case object v2_3 extends HiveVersion("2.3.9", exclusions = Seq("org.apache.calcite:calcite-core", "org.apache.calcite:calcite-druid", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index 395ee86579e57..779562bed5b0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, PythonUDF, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation @@ -50,7 +50,12 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) filters: Seq[Expression], relation: HiveTableRelation): ExpressionSet = { val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output) + filters.filter { f => + f.deterministic && + !SubqueryExpression.hasSubquery(f) && + // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``. 
+ !f.exists(_.isInstanceOf[PythonUDF]) + }, relation.output) val partitionColumnSet = AttributeSet(relation.partitionCols) ExpressionSet( normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala index 094f8ba7a0f89..fc1c795a1aa1c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala @@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator( override def returnType: DataType = inspectorToDataType(returnInspector) def setArg(index: Int, arg: Any): Unit = - deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg) + deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg) + + def setException(index: Int, exp: Throwable): Unit = { + deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw exp) + } override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects)) } @@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp extends DeferredObject with HiveInspectors { private val wrapper = wrapperFor(oi, dataType) - private var func: Any = _ - def set(func: Any): Unit = { + private var func: () => Any = _ + def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef] + override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 01684f52ab82b..0c8305b3ccb24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF( override def eval(input: InternalRow): Any = { children.zipWithIndex.foreach { - case (child, idx) => evaluator.setArg(idx, child.eval(input)) + case (child, idx) => + try { + evaluator.setArg(idx, child.eval(input)) + } catch { + case t: Throwable => + evaluator.setException(idx, t) + } } evaluator.evaluate() } @@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF( val setValues = evals.zipWithIndex.map { case (eval, i) => s""" - |if (${eval.isNull}) { - | $refEvaluator.setArg($i, null); - |} else { - | $refEvaluator.setArg($i, ${eval.value}); + |try { + | ${eval.code} + | if (${eval.isNull}) { + | $refEvaluator.setArg($i, null); + | } else { + | $refEvaluator.setArg($i, ${eval.value}); + | } + |} catch (Throwable t) { + | $refEvaluator.setException($i, t); |} |""".stripMargin } @@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF( val resultTerm = ctx.freshName("result") ev.copy(code = code""" - |${evals.map(_.code).mkString("\n")} |${setValues.mkString("\n")} |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(); |boolean ${ev.isNull} = $resultTerm == null; diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index a9314397dcf67..6a0b9686ffce1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -71,11 +71,10 @@ class OrcFileFormat extends 
FileFormat with DataSourceRegister with Serializable SchemaMergeUtils.mergeSchemasInParallel( sparkSession, options, files, OrcFileOperator.readOrcSchemasInParallel) } else { - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), Some(sparkSession.sessionState.newHadoopConfWithOptions(options)), - ignoreCorruptFiles + orcOptions.ignoreCorruptFiles ) } } @@ -146,7 +145,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + val ignoreCorruptFiles = + new OrcOptions(options, sparkSession.sessionState.conf).ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java new file mode 100644 index 0000000000000..242dbeaa63c94 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +public class UDFCatchException extends GenericUDF { + + @Override + public ObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException { + if (args.length != 1) { + throw new UDFArgumentException("Exactly one argument is expected."); + } + return PrimitiveObjectInspectorFactory.javaStringObjectInspector; + } + + @Override + public Object evaluate(GenericUDF.DeferredObject[] args) { + if (args == null) { + return null; + } + try { + return args[0].get(); + } catch (Exception e) { + return null; + } + } + + @Override + public String getDisplayString(String[] children) { + return null; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java new file mode 100644 index 0000000000000..5d6ff6ca40ae5 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFThrowException extends UDF { + public String evaluate(String data) { + return Integer.valueOf(data).toString(); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/connector/HiveSourceRowLevelOperationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/connector/HiveSourceRowLevelOperationSuite.scala new file mode 100644 index 0000000000000..344fdc21fe2cf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/connector/HiveSourceRowLevelOperationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hive.connector + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveSourceRowLevelOperationSuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + before { + spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") + } + + test("SPARK-45943: merge into using hive table without stats") { + val inMemCatNs = "cat.ns1" + val inMemCatTable = "in_mem_cat_table" + withTable("hive_table", s"$inMemCatNs.$inMemCatTable") { + // create hive table without stats + sql("create table hive_table(pk int, salary int, dep string)") + + sql( + s""" + |create table $inMemCatNs.$inMemCatTable ( + | pk INT NOT NULL, + | salary INT, + | dep STRING) + |PARTITIONED BY (dep) + | """.stripMargin) + + try { + // three-part naming is not supported in + // org.apache.spark.sql.hive.test.TestHiveQueryExecution.analyzed.{referencedTables} + sql(s"use $inMemCatNs") + sql( + s"""MERGE INTO $inMemCatTable t + |USING (SELECT pk, salary, dep FROM spark_catalog.default.hive_table) s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET t.salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } finally { + sql("use spark_catalog.default") + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 2c5e2956f5f8a..2fad78e84f49d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -56,6 +56,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA try { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() + spark.sessionState.catalogManager.reset() } finally { super.afterEach() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 82b88ec9f35d6..4b85b37b6c2c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -161,28 +161,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd | SELECT key FROM gen_tmp ORDER BY key ASC; """.stripMargin) - test("multiple generators in projection") { - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) - - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(array(key, key))\", 
\"explode(array(key, key))\"")) - } - createQueryTest("! operator", """ |SELECT a FROM ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index d12ebae0f5fc7..f3be79f902294 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.{SparkException, SparkFiles, TestUtils} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions.{call_function, max} @@ -791,6 +792,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("SPARK-48845: GenericUDF catch exceptions from child UDFs") { + withTable("test_catch_exception") { + withUserDefinedFunction("udf_throw_exception" -> true, "udf_catch_exception" -> true) { + Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception") + sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " + + s"'${classOf[UDFThrowException].getName}'") + sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " + + s"'${classOf[UDFCatchException].getName}'") + Seq( + CodegenObjectFactoryMode.FALLBACK.toString, + CodegenObjectFactoryMode.NO_CODEGEN.toString + ).foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + val df = sql( + "SELECT udf_catch_exception(udf_throw_exception(a)) FROM test_catch_exception") + checkAnswer(df, Seq(Row("9"), Row(null))) + } + } + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9308d1eda146f..6160c3e5f6c65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2660,6 +2660,32 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkAnswer(df, Seq.empty[Row]) } } + + test("SPARK-46388: HiveAnalysis convert InsertIntoStatement to InsertIntoHiveTable " + + "iff child resolved") { + withTable("t") { + sql("CREATE TABLE t (a STRING)") + checkError( + exception = intercept[AnalysisException](sql("INSERT INTO t SELECT a*2 FROM t where b=1")), + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = None, + parameters = Map("objectName" -> "`b`", "proposal" -> "`a`"), + context = ExpectedContext( + fragment = "b", + start = 38, + stop = 38) ) + checkError( + exception = intercept[AnalysisException]( + sql("INSERT INTO t SELECT cast(a as short) FROM t where b=1")), + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = None, + parameters = Map("objectName" -> "`b`", "proposal" -> "`a`"), + context = ExpectedContext( + fragment = "b", + start = 51, + stop = 51)) + } + } } @SlowHiveTest diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index e52d9b639dc4f..ccf1d3df83efe 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -381,4 +381,23 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-49094: ignoreCorruptFiles works for hive orc w/ mergeSchema off") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 1).toDF("a").write.orc(new Path(basePath, "foo=1").toString) + spark.range(0, 1).toDF("b").write.json(new Path(basePath, "foo=2").toString) + + withSQLConf( + SQLConf.IGNORE_CORRUPT_FILES.key -> "false", + SQLConf.ORC_IMPLEMENTATION.key -> "hive") { + Seq(true, false).foreach { mergeSchema => + checkAnswer(spark.read + .option("mergeSchema", value = mergeSchema) + .option("ignoreCorruptFiles", value = true) + .orc(basePath), Row(0L, 1)) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9284b35fb3e35..1d646f40b3e28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -543,6 +543,7 @@ private[hive] class TestHiveSparkSession( sharedState.cacheManager.clearCache() sharedState.loadedTables.clear() sessionState.catalog.reset() + sessionState.catalogManager.reset() metadataHive.reset() // HDFS root scratch dir requires the write all (733) permission. For each connecting user, diff --git a/streaming/pom.xml b/streaming/pom.xml index 55758d75ce54d..6cbccb39772c9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 43aaa7e1eeaec..a8f55c8b4d641 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { outputStreams.foreach(_.validateAtStart()) numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]]) inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq + // scalastyle:off parvector new ParVector(inputStreams.toVector).foreach(_.start()) + // scalastyle:on parvector } } @@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def stop(): Unit = { this.synchronized { + // scalastyle:off parvector new ParVector(inputStreams.toVector).foreach(_.stop()) + // scalastyle:on parvector } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d1f9dfb791355..4e65bc75e4395 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog { val groupSize = taskSupport.parallelismLevel.max(8) source.grouped(groupSize).flatMap { group => + // scalastyle:off parvector val parallelCollection = new ParVector(group.toVector) parallelCollection.tasksupport = taskSupport + // scalastyle:on parvector 
parallelCollection.map(handler) }.flatten } diff --git a/tools/pom.xml b/tools/pom.xml index a63b2e1062dd8..f23f4a4b50559 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0 + 3.5.4-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index a46a7fbeec497..be863b52c2500 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -44,8 +44,15 @@ object GenerateMIMAIgnore { private def isPackagePrivate(sym: unv.Symbol) = !sym.privateWithin.fullName.startsWith("") - private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) = + private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) = try { !moduleSymbol.privateWithin.fullName.startsWith("") + } catch { + case e: Throwable => + // scalastyle:off println + println("[WARN] Unable to check module:" + moduleSymbol) + // scalastyle:on println + false + } /** * For every class checks via scala reflection if the class itself or contained members
    <tr><th>SQL metrics</th><th>Meaning</th><th>Operators</th></tr>
    <tr><td> <code>number of output rows</code> </td><td> the number of output rows of the operator </td><td> Aggregate operators, Join operators, Sample, Range, Scan operators, Filter, etc. </td></tr>
    <tr><td> <code>data size</code> </td><td> the size of broadcast/shuffled/collected data of the operator </td><td> BroadcastExchange, ShuffleExchange, Subquery </td></tr>