diff --git a/.bazelrc b/.bazelrc index 1538ddd284f616..bd71ab3016563d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -255,6 +255,14 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.16-clang_config_cuda//crosstool:toolchain" +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang +build:nvcc_clang --action_env=TF_CUDA_CLANG="1" +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + + # Debug config build:dbg -c dbg # Only include debug info for files under tensorflow/, excluding kernels, to @@ -527,8 +535,8 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda +build:rbe_linux_cuda_nvcc --config=nvcc_clang build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 -build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" # TODO(kanglan): Remove rbe_win and rbe_win_py3* after b/289091160 is fixed build:rbe_win --config=rbe_base @@ -577,6 +585,7 @@ build:elinux_armhf --copt -mfp16-format=ieee # Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc +try-import %workspace%/xla_configure.bazelrc # Load rc file with user-specific options. try-import %workspace%/.bazelrc.user @@ -719,7 +728,7 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs # push to the cache. 
For macOS, use --config=tf_public_macos_cache -build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false +build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/january2024" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials # Public cache for macOS builds @@ -777,28 +786,38 @@ test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-os test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 PYCPP -test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 +# In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on +# Linux x86 so that we can use RBE. Since tests still need to run on the single +# host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. +# For testing purposes, we want to see the runtime performance of an +# experimental job that is build-only, i.e, we only build the test targets and +# do not run them. By prefixing the configs with "build", we can run both +# `bazel build` and `bazel test` commands with the same config as test configs +# inherit from build. 
+build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP -test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled -test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test +build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... 
-//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP -test:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +# These are defined as build configs so that we can run a build only job. See +# the note under "ARM64 PYCPP" for more details. +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
# CROSS-COMPILE MACOS X86 PYCPP -test:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test -test:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test +build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS @@ -855,8 +874,12 @@ build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cr # RBE cross-compile configs for Darwin x86 build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base +build:rbe_cross_compile_macos_x86 --bes_upload_mode=fully_async test:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base # Increase the test timeout as tests often take longer on mac. 
test:rbe_cross_compile_macos_x86 --test_timeout=300,450,1200,3600 +# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) +build:rbe_cross_compile_macos_x86 --jobs=100 +test:rbe_cross_compile_macos_x86 --jobs=100 # END MACOS CROSS-COMPILE CONFIGS # END CROSS-COMPILE CONFIGS diff --git a/.bazelversion b/.bazelversion index b536fbc5061305..f3c238740e5bc3 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -6.1.0 +6.5.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index 15433f8f14be32..ddcc1b373e5c14 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -57,12 +57,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository for releases (skipped for nightly) if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build and test pip wheel shell: bash run: | diff --git a/.github/workflows/arm-ci-extended-cpp.yml b/.github/workflows/arm-ci-extended-cpp.yml index e648297d37e789..2f9c67fb81eede 100644 --- a/.github/workflows/arm-ci-extended-cpp.yml +++ b/.github/workflows/arm-ci-extended-cpp.yml @@ -50,12 +50,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run C++ tests shell: bash run: | diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 01ce70ba82ecfa..db782d3cf35f30 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -51,12 +51,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run python tests on nightly for all python versions shell: bash run: | diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index 3b07683008391d..7b3e8c6f24df49 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -47,7 +47,7 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Build binary and run python tests shell: bash run: | diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index fb7366768436c5..a471d68b4fd2d7 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner/.github/workflows/osv-scanner-reusable.yml@main" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.6.2-beta1" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 1b421effec8198..bdce23b94d02f1 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -106,13 +106,13 @@ jobs: map sigbuild-r2.14-clang-python3.10 2.14-python3.10 map sigbuild-r2.14-clang-python3.11 2.14-python3.11 # TF 2.16 - map sigbuild-r2.16 2.16-python3.9 + map sigbuild-r2.16 2.16-python3.11 map sigbuild-r2.16-python3.9 2.16-python3.9 map sigbuild-r2.16-python3.10 2.16-python3.10 map sigbuild-r2.16-python3.11 2.16-python3.11 map sigbuild-r2.16-python3.12 2.16-python3.12 # TF 2.16 + Clang (containers are the same, but env vars in configs.bzl are different) - map sigbuild-r2.16-clang 2.16-python3.9 + map sigbuild-r2.16-clang 2.16-python3.11 map sigbuild-r2.16-clang-python3.9 2.16-python3.9 map sigbuild-r2.16-clang-python3.10 2.16-python3.10 map sigbuild-r2.16-clang-python3.11 2.16-python3.11 diff --git a/.gitignore b/.gitignore index cebef4f590d47e..614cde3446a16f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ node_modules /.bazelrc.user /.tf_configure.bazelrc 
+/xla_configure.bazelrc /bazel-* /bazel_pip /tools/python_bin_path.sh diff --git a/RELEASE.md b/RELEASE.md index 784e2ac28ceea7..e75ca35b589d73 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -# Release 2.16.0 +# Release 2.17.0 ## TensorFlow @@ -9,11 +9,31 @@ * * -* `tf.summary.trace_on` now takes a `profiler_outdir` argument. This must be set - if `profiler` arg is set to `True`. - * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. Enabling - the profiler now requires setting `profiler_outdir` in `trace_on`. +### Known Caveats +* +* +* + +### Major Features and Improvements + +* +* + +### Bug Fixes and Other Changes + +* +* +* + +## Keras + + + +### Breaking Changes + +* +* ### Known Caveats @@ -26,6 +46,101 @@ * * +### Bug Fixes and Other Changes + +* +* +* + +* `tf.lite` + * Quantization for `FullyConnected` layer is switched from per-tensor to + per-channel scales for dynamic range quantization use case (`float32` + inputs / outputs and `int8` weights). The change enables new quantization + schema globally in the converter and inference engine. The new behaviour + can be disabled via experimental + flag `converter._experimental_disable_per_channel_quantization_for_dense_layers = True`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + +# Release 2.16.0 + +## TensorFlow + + + +* TensorFlow Windows Build: + + * Clang is now the default compiler to build TensorFlow CPU wheels on the + Windows Platform starting with this release. The currently supported + version is LLVM/clang 17. The official Wheels-published on PyPI will be + based on Clang; however, users retain the option to build wheels using + the MSVC compiler following the steps mentioned in + https://www.tensorflow.org/install/source_windows as has been the case + before + +### Breaking Changes + +* +* + +* `tf.summary.trace_on` now takes a `profiler_outdir` argument. 
This must be + set if `profiler` arg is set to `True`. + + * `tf.summary.trace_export`'s `profiler_outdir` arg is now a no-op. + Enabling the profiler now requires setting `profiler_outdir` in + `trace_on`. + +* `tf.estimator` + + * The tf.estimator API is removed. + +* Keras 3.0 will be the default Keras version. You may need to update your + script to use Keras 3.0. + +* Please refer to the new Keras documentation for Keras 3.0 + (https://keras.io/keras_3). + +* To continue using Keras 2.0, do the following. + +* 1. Install tf-keras via pip install tf-keras~=2.16 + + 1. To switch tf.keras to use Keras 2 (tf-keras), set the environment + variable TF_USE_LEGACY_KERAS=1 directly or in your python program by + import os;os.environ["TF_USE_LEGACY_KERAS"]=1. Please note that this + will set it for all packages in your Python runtime program + +* 1. Change import of keras from tensorflow as follows +* import tensorflow.keras as keras and import keras to import tf_keras as + keras +* **Apple Silicon users:** If you previously installed TensorFlow using + `pip install tensorflow-macos`, please update your installation method. Use + `pip install tensorflow` from now on. Starting with TF 2.17, the + `tensorflow-macos` package will no longer receive updates. + +### Known Caveats + +* +* +* + +* Full aarch64 Linux and Arm64 macOS wheels are now published to the + `tensorflow` pypi repository and no longer redirect to a separate package. + +### Major Features and Improvements + +* +* + +* Support for Python 3.12 has been added. +* [tensorflow-tpu](https://pypi.org/project/tensorflow-tpu/) package is now + available for easier TPU based installs. +* TensorFlow pip packages are now built with CUDA 12.3 and cuDNN 8.9.7 + + ### Bug Fixes and Other Changes * @@ -54,6 +169,21 @@ * Added `experimental_skip_slot_variables` (a boolean option) to skip restoring of optimizer slot variables in a checkpoint. 
+* `tf.saved_model.SaveOptions` + + * `SaveOptions` now takes a new argument called + `experimental_debug_stripper`. When enabled, this strips the debug nodes + from both the node defs and the function defs of the graph. Note that + this currently only strips the `Assert` nodes from the graph and + converts them into `NoOp`s instead. + +* `tf.data` + + * `tf.data` now has an `autotune_options.initial_parallelism` option to + control the initial parallelism setting used by autotune before the data + pipeline has started running. The default is 16. A lower value reduces + initial memory usage, while a higher value improves startup time. + ## Keras * `keras.layers.experimental.DynamicEmbedding` diff --git a/ci/official/README.md b/ci/official/README.md index d070af86cd8090..3c0181c5384392 100644 --- a/ci/official/README.md +++ b/ci/official/README.md @@ -45,7 +45,7 @@ cd tensorflow-git-dir # Here is a single-line example of running a script on Linux to build the # GPU version of TensorFlow for Python 3.12, using the public TF bazel cache and # a local build cache: -TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh +TFCI=py312,linux_x86_cuda,public_cache,disk_cache ci/official/wheel.sh # First, set your TFCI variable to choose the environment settings. # TFCI is a comma-separated list of filenames from the envs directory, which @@ -57,9 +57,10 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # value in the "env_vars" list that you can choose to copy that environment. # Ex. 1: TFCI=py311,linux_x86_cuda,nightly_upload (nightly job) # Ex. 2: TFCI=py39,linux_x86,rbe (continuous job) -# Non-Googlers should replace "nightly_upload" or "rbe" with "multicache". -# Googlers should replace "nightly_upload" with "multicache" or "rbe", if -# you have set up your system to use RBE (see further below). +# Non-Googlers should replace "nightly_upload" or "rbe" with +# "public_cache,disk_cache". 
+# Googlers should replace "nightly_upload" with "public_cache,disk_cache" or +# "rbe", if you have set up your system to use RBE (see further below). # # Here is how to choose your TFCI value: # 1. A Python version must come first, because other scripts reference it. @@ -74,7 +75,9 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # Ex. linux_x86_cuda -- x86_64 Linux platform, with Nvidia CUDA support # Ex. macos_arm64 -- arm64 MacOS platform # 3. Add modifiers. Some modifiers for local execution are: -# Ex. multicache -- Use a local cache combined with TF's public cache +# Ex. disk_cache -- Use a local cache +# Ex. public_cache -- Use TF's public cache (read-only) +# Ex. public_cache_push -- Use TF's public cache (read and write, Googlers only) # Ex. rbe -- Use RBE for faster builds (Googlers only; see below) # Ex. no_docker -- Disable docker on enabled platforms # See full examples below for more details on these. Some other modifiers are: @@ -94,7 +97,7 @@ TFCI=py312,linux_x86_cuda,multicache ci/official/wheel.sh # or tests passing incorrectly. # - Automatic LLVM updates are known to extend build time even with # the cache; this is unavoidable. -export TFCI=py311,linux_x86,multicache +export TFCI=py311,linux_x86,public_cache,disk_cache # Recommended: Configure Docker. (Linux only) # @@ -127,7 +130,7 @@ export TFCI=py311,linux_x86,multicache # it is only available to a limited set of internal TensorFlow developers. # # RBE is incompatible with local caching, so you must remove -# ci/official/envs/local_multicache from your $TFCI file. +# disk_cache, public_cache, and public_cache_push from your $TFCI file. # # To use RBE, you must first run `gcloud auth application-default login`, then: export TFCI=py311,linux_x86,rbe diff --git a/ci/official/any.sh b/ci/official/any.sh index 980bb3cfdf403a..dc1484b64dc9ea 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -29,7 +29,7 @@ # ./any.sh # # 3. 
DO THE SAME WITH A LOCAL CACHE OR RBE: -# export TF_ANY_EXTRA_ENV=ci/official/envs/local_multicache +# export TF_ANY_EXTRA_ENV=ci/official/envs/public_cache,ci/official/envs/disk_cache # ... # ./any.sh # or @@ -39,8 +39,8 @@ set -euxo pipefail cd "$(dirname "$0")/../../" # tensorflow/ # Any request that includes "nightly_upload" should just use the -# local multi-cache instead. -export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,multicache/')" +# local multi-cache (public read-only cache + disk cache) instead. +export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,public_cache,disk_cache/')" if [[ -n "${TF_ANY_EXTRA_ENV:-}" ]]; then export TFCI="$TFCI,$TF_ANY_EXTRA_ENV" fi diff --git a/ci/official/bisect.sh b/ci/official/bisect.sh index 4076a73b867e7a..7f18dd1460ff5b 100755 --- a/ci/official/bisect.sh +++ b/ci/official/bisect.sh @@ -34,6 +34,6 @@ # export TF_ANY_MODE=test set -euxo pipefail cd "$(dirname "$0")/../../" # tensorflow/ -export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,multicache/')" +export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,public_cache,disk_cache/')" git bisect start "$TF_BISECT_BAD" "$TF_BISECT_GOOD" git bisect run $TF_BISECT_SCRIPT diff --git a/ci/official/code_check_full.sh b/ci/official/code_check_full.sh index 5b1370b4f31e06..448fb82bf288b9 100755 --- a/ci/official/code_check_full.sh +++ b/ci/official/code_check_full.sh @@ -15,4 +15,4 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output "$TFCI_OUTPUT_DIR" +tfrun bats ./ci/official/utilities/code_check_full.bats --timing --output "$TFCI_OUTPUT_DIR" \ No newline at end of file diff --git a/ci/official/containers/linux_arm64/devel.packages.txt b/ci/official/containers/linux_arm64/devel.packages.txt index efbae80eefacee..a8a9cb442c8b0b 100644 --- a/ci/official/containers/linux_arm64/devel.packages.txt +++ 
b/ci/official/containers/linux_arm64/devel.packages.txt @@ -3,8 +3,6 @@ autoconf automake build-essential ca-certificates -# TODO(b/308399490) Remove CMake once dm-tree (Keras dependency) has 3.12 wheels -cmake llvm-17 clang-17 clang-format-12 diff --git a/ci/official/containers/linux_arm64/jax.requirements.txt b/ci/official/containers/linux_arm64/jax.requirements.txt index 878d229d0f237e..6211528986fdc0 100644 --- a/ci/official/containers/linux_arm64/jax.requirements.txt +++ b/ci/official/containers/linux_arm64/jax.requirements.txt @@ -24,4 +24,6 @@ scipy==1.9.2;python_version=="3.11" scipy==1.7.3;python_version<"3.11" ml_dtypes>=0.2.0 +# For using Python 3.11 with Bazel 6 (b/286090018) +lit ~= 17.0.2 diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index d080a4566efe16..96d87423392541 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -58,6 +58,7 @@ TFCI_MACOS_UPGRADE_PYENV_ENABLE= TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= TFCI_OUTPUT_DIR= +TFCI_PYCPP_SWAP_TO_BUILD_ENABLE= TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS= TFCI_PYTHON_VERSION= TFCI_WHL_AUDIT_ENABLE= diff --git a/ci/official/envs/disk_cache b/ci/official/envs/disk_cache new file mode 100644 index 00000000000000..bd8ccfa0d95692 --- /dev/null +++ b/ci/official/envs/disk_cache @@ -0,0 +1,20 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Sourcing this enables local disk cache +if [[ $(uname -s) == "Darwin" ]]; then + echo "Please note that using disk cache on macOS is not recommended because the" + echo "cache can end up being pretty big and make the build process inefficient." +fi +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --disk_cache=$TFCI_OUTPUT_DIR/cache" diff --git a/ci/official/envs/enable_pycpp_build b/ci/official/envs/enable_pycpp_build new file mode 100644 index 00000000000000..d7e0e5ea8065c3 --- /dev/null +++ b/ci/official/envs/enable_pycpp_build @@ -0,0 +1,20 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Changes the behavior in pycpp.sh from "run all tests" to "verify that all +# tests can compile." Used in some CI jobs (macOS and Linux Arm64) where test +# execution is too expensive. +TFCI_PYCPP_SWAP_TO_BUILD_ENABLE=1 +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --build_tests_only" \ No newline at end of file diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 7c4270408dd68d..161b0e2e803822 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ -19,7 +19,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 # despite lacking Nvidia CUDA support. 
TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow" TFCI_DOCKER_ENABLE=1 -TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-2-16-multi-python TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" TFCI_INDEX_HTML_ENABLE=1 diff --git a/ci/official/envs/linux_arm64_onednn b/ci/official/envs/linux_arm64_onednn new file mode 100644 index 00000000000000..0d4b7cbd03bbaa --- /dev/null +++ b/ci/official/envs/linux_arm64_onednn @@ -0,0 +1,16 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +source ci/official/envs/linux_arm64 +TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --test_env=TF_ENABLE_ONEDNN_OPTS=1" diff --git a/ci/official/envs/linux_x86 b/ci/official/envs/linux_x86 index 97fe9956f14ee1..2efc0fac466b00 100644 --- a/ci/official/envs/linux_x86 +++ b/ci/official/envs/linux_x86 @@ -16,7 +16,7 @@ TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --conf TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow_cpu" TFCI_DOCKER_ENABLE=1 -TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} +TFCI_DOCKER_IMAGE=tensorflow/build:2.16-python${TFCI_PYTHON_VERSION} TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_INDEX_HTML_ENABLE=1 diff --git a/ci/official/envs/macos_x86 b/ci/official/envs/macos_x86 index 3959830535628b..56166a0d0d4309 100644 --- a/ci/official/envs/macos_x86 +++ b/ci/official/envs/macos_x86 @@ -22,8 +22,8 @@ TFCI_MACOS_BAZEL_TEST_DIR_PATH="/System/Volumes/Data/bazel_output" TFCI_MACOS_INSTALL_BAZELISK_ENABLE=1 TFCI_MACOS_INSTALL_BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" TFCI_MACOS_TWINE_INSTALL_ENABLE=1 -TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 TFCI_OUTPUT_DIR=build_output +TFCI_WHL_BAZEL_TEST_ENABLE=1 TFCI_WHL_SIZE_LIMIT=255M TFCI_WHL_SIZE_LIMIT_ENABLE=1 diff --git a/ci/official/envs/macos_x86_cross_compile b/ci/official/envs/macos_x86_cross_compile index 79f717156ea939..3a9dd2557faa1c 100644 --- a/ci/official/envs/macos_x86_cross_compile +++ b/ci/official/envs/macos_x86_cross_compile @@ -13,8 +13,7 @@ # limitations under the License. 
# ============================================================================== source ci/official/envs/macos_x86 -# Limit jobs to 100 to avoid running into "out of memory" issues (b/316266643) -TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --jobs=100 --config cross_compile_macos_x86" +TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_macos_x86" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_macos_x86 TFCI_MACOS_CROSS_COMPILE_ENABLE=1 TFCI_MACOS_CROSS_COMPILE_SDK_DEST="tensorflow/tools/toolchains/cross_compile/cc/MacOSX.sdk" diff --git a/ci/official/envs/multicache b/ci/official/envs/public_cache similarity index 85% rename from ci/official/envs/multicache rename to ci/official/envs/public_cache index eb5c58e68e646f..ec57aad869ca47 100644 --- a/ci/official/envs/multicache +++ b/ci/official/envs/public_cache @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Combine TF public build cache and local disk cache +# Sourcing this enables Bazel remote cache (public, read-only) # The cache configs are different for MacOS and Linux if [[ $(uname -s) == "Darwin" ]]; then - TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache --disk_cache=$TFCI_OUTPUT_DIR/cache" + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache" else - TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache --disk_cache=$TFCI_OUTPUT_DIR/cache" + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache" fi diff --git a/ci/official/envs/public_cache_push b/ci/official/envs/public_cache_push new file mode 100644 index 00000000000000..e686a0aac5d5ce --- /dev/null +++ b/ci/official/envs/public_cache_push @@ -0,0 +1,24 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Sourcing this enables Bazel remote cache (read and write) +# Note that "_push" cache configs write to GCS buckets and require +# authentication. If you are not a Googler, source "public_cache" to enable the +# public read-only cache. +# The cache configs are different for MacOS and Linux +if [[ $(uname -s) == "Darwin" ]]; then + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_macos_cache_push" +else + TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config tf_public_cache_push" +fi diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 34294fe8a107f6..cf346007949c1e 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -15,7 +15,11 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" -tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then + tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +else + tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" +fi # Note: the profile can 
be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling diff --git a/ci/official/requirements_updater/.bazelversion b/ci/official/requirements_updater/.bazelversion new file mode 100644 index 00000000000000..f22d756da39d4c --- /dev/null +++ b/ci/official/requirements_updater/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/ci/official/requirements_updater/release_updater.sh b/ci/official/requirements_updater/release_updater.sh index 88d54666eb21db..3d47199c7187af 100644 --- a/ci/official/requirements_updater/release_updater.sh +++ b/ci/official/requirements_updater/release_updater.sh @@ -25,7 +25,10 @@ SUPPORTED_VERSIONS=("3_9" "3_10" "3_11" "3_12") for VERSION in "${SUPPORTED_VERSIONS[@]}" do cp ../../../requirements_lock_"$VERSION".txt "requirements_lock_"$VERSION".txt" - bazel run --experimental_convenience_symlinks=ignore //:requirements_"$VERSION"_release.update + bazel run \ + --experimental_convenience_symlinks=ignore \ + --enable_bzlmod=false \ + //:requirements_"$VERSION"_release.update sed -i '/^#/d' requirements_lock_"$VERSION".txt mv "requirements_lock_"$VERSION".txt" ../../../requirements_lock_"$VERSION".txt done diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 46b5a532d5bb17..364134fcf7c39b 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -11,7 +11,7 @@ astor == 0.7.1 typing_extensions == 4.8.0 gast == 0.4.0 termcolor == 2.3.0 -wrapt == 1.14.1 +wrapt == 1.16.0 tblib == 2.0.0 # Install tensorboard, and keras @@ -19,7 +19,7 @@ tblib == 2.0.0 # Note that we must use nightly here as these are used in nightly jobs # For release jobs, we will pin these on the release branch keras-nightly ~= 3.0.0.dev -tb-nightly ~= 2.15.0.a +tb-nightly ~= 2.17.0.a # Test dependencies grpcio >= 1.24.3, < 2.0 diff --git 
a/ci/official/requirements_updater/updater.sh b/ci/official/requirements_updater/updater.sh index 898151dab1b599..95c67322966d11 100755 --- a/ci/official/requirements_updater/updater.sh +++ b/ci/official/requirements_updater/updater.sh @@ -28,7 +28,10 @@ SUPPORTED_VERSIONS=("3_9" "3_10" "3_11" "3_12") for VERSION in "${SUPPORTED_VERSIONS[@]}" do touch "requirements_lock_$VERSION.txt" - bazel run --experimental_convenience_symlinks=ignore //:requirements_"$VERSION".update + bazel run \ + --experimental_convenience_symlinks=ignore \ + --enable_bzlmod=false \ + //:requirements_"$VERSION".update sed -i '/^#/d' requirements_lock_"$VERSION".txt mv requirements_lock_"$VERSION".txt ../../../requirements_lock_"$VERSION".txt done diff --git a/ci/official/upload.sh b/ci/official/upload.sh index d3411cb2284876..28ad6a68409419 100755 --- a/ci/official/upload.sh +++ b/ci/official/upload.sh @@ -25,28 +25,36 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then fi source ci/official/utilities/get_versions.sh -# $TF_VER_FULL will resolve to e.g. "2.15.0-rc2". When given a directory path, -# gsutil copies the basename of the directory into the provided path. Uploading -# /tmp/$TF_VER_FULL to gs://bucket/ will create gs://bucket/$TF_VER_FULL/files. -# Since $TF_VER_FULL comes from get_versions.sh, which must be run *after* -# update_version.py, this can't be set inside the rest of the _upload envs. -DOWNLOADS="$(mktemp -d)/$TF_VER_FULL" +# Note on gsutil commands: +# "gsutil cp" always "copies into". It cannot act on the contents of a directory +# and it does not seem possible to e.g. copy "gs://foo/bar" as anything other than +# "/path/bar". This script uses "gsutil rsync" instead, which acts on directory +# contents. About arguments to gsutil: +# "gsutil -m rsync" runs in parallel. +# "gsutil rsync -r" is recursive and makes directories work. 
+# "gsutil rsync -d" is "sync and delete files from destination if not present in source" + +DOWNLOADS="$(mktemp -d)" mkdir -p "$DOWNLOADS" -# -r is needed to copy a whole folder. -gsutil -m cp -r "$TFCI_ARTIFACT_STAGING_GCS_URI" "$DOWNLOADS" +gsutil -m rsync -r "$TFCI_ARTIFACT_STAGING_GCS_URI" "$DOWNLOADS" ls "$DOWNLOADS" # Upload all build artifacts to e.g. gs://tensorflow/versions/2.16.0-rc1 (releases) or # gs://tensorflow/nightly/2.16.0-dev20240105 (nightly), overwriting previous values. if [[ "$TFCI_ARTIFACT_FINAL_GCS_ENABLE" == 1 ]]; then gcloud auth activate-service-account --key-file="$TFCI_ARTIFACT_FINAL_GCS_SA_PATH" - gsutil -m cp -r "$DOWNLOADS" "$TFCI_ARTIFACT_FINAL_GCS_URI" + + # $TF_VER_FULL will resolve to e.g. "2.15.0-rc2". Since $TF_VER_FULL comes + # from get_versions.sh, which must be run *after* update_version.py, FINAL_URI + # can't be set inside the rest of the _upload envs. + FINAL_URI="$TFCI_ARTIFACT_FINAL_GCS_URI/$TF_VER_FULL" + gsutil -m rsync -d -r "$DOWNLOADS" "$FINAL_URI" + # Also mirror the latest-uploaded folder to the "latest" directory. - # GCS does not support symlinks. -p preserves ACLs. -d deletes - # no-longer-present files (it's what makes this act as a mirror). - gsutil rsync -d -p -r "$TFCI_ARTIFACT_FINAL_GCS_URI" "$TFCI_ARTIFACT_LATEST_GCS_URI" + # GCS does not support symlinks. + gsutil -m rsync -d -r "$FINAL_URI" "$TFCI_ARTIFACT_LATEST_GCS_URI" fi if [[ "$TFCI_ARTIFACT_FINAL_PYPI_ENABLE" == 1 ]]; then - twine upload $TFCI_UPLOAD_WHL_PYPI_ARGS "$DOWNLOADS"/*.whl + twine upload $TFCI_ARTIFACT_FINAL_PYPI_ARGS "$DOWNLOADS"/*.whl fi diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index 78dd88f1d56be6..8dacee0875535d 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -306,6 +306,12 @@ EOF echo "Look at the instructions for ':api_compatibility_test -- --update_goldens=True'" } +# See b/279852433 (internal). 
+# TODO(b/279852433) Replace deps(//tensorflow/...) with deps(//...) +@test "Verify that it's possible to query every TensorFlow target without BUILD errors" { + bazel query "deps(//tensorflow/...)" > /dev/null +} + teardown_file() { bazel shutdown } diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index 5a03886c96eab9..10e847d0679391 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -88,18 +88,6 @@ if [[ "${OSTYPE}" =~ darwin* ]]; then source ./ci/official/utilities/setup_macos.sh fi -# Force-disable uploads if the job initiator is not louhi-bridge-server -# (the user that triggers all of the nightly and release jobs on Louhi) -# This is temporary: it's currently standard practice for employees to -# run nightly jobs for testing purposes. We're aiming to move away from -# this with more convenient methods, but as long as it's possible to do, -# we want to make sure those extra jobs don't upload anything. -# TODO(angerson) Remove this once artifact staging is done; after that, -# simply running a nightly again will not risk upload anything. -if [[ "${KOKORO_BUILD_INITIATOR:-}" != "louhi-bridge-server" ]]; then - source ./ci/official/envs/no_upload -fi - # Create and expand to the full path of TFCI_OUTPUT_DIR export TFCI_OUTPUT_DIR=$(realpath "$TFCI_OUTPUT_DIR") mkdir -p "$TFCI_OUTPUT_DIR" diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh index c50ea618cfea6c..36afa2545eb244 100755 --- a/ci/official/utilities/setup_docker.sh +++ b/ci/official/utilities/setup_docker.sh @@ -14,8 +14,13 @@ # limitations under the License. # ============================================================================== if [[ "$TFCI_DOCKER_PULL_ENABLE" == 1 ]]; then + # Simple retry logic for docker-pull errors. Sleeps for 15s if a pull fails. + # Pulling an already-pulled container image will finish instantly, so + # repeating the command costs nothing. 
+ docker pull "$TFCI_DOCKER_IMAGE" || sleep 15 + docker pull "$TFCI_DOCKER_IMAGE" || sleep 15 docker pull "$TFCI_DOCKER_IMAGE" -fi +fi if [[ "$TFCI_DOCKER_REBUILD_ENABLE" == 1 ]]; then DOCKER_BUILDKIT=1 docker build --cache-from "$TFCI_DOCKER_IMAGE" -t "$TFCI_DOCKER_IMAGE" $TFCI_DOCKER_REBUILD_ARGS diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index f5b07d565437e8..8a63d318c6e18e 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -74,19 +74,13 @@ fi # it. TFCI Mac VMs only have one Python version installed so we need to install # the other versions manually. if [[ "${TFCI_MACOS_PYENV_INSTALL_ENABLE}" == 1 ]]; then - pyenv install "$TFCI_PYTHON_VERSION" + # Install the necessary Python, unless it's already present + pyenv install -s "$TFCI_PYTHON_VERSION" pyenv local "$TFCI_PYTHON_VERSION" # Do a sanity check to make sure that we using the correct Python version python --version fi -if [[ "$TFCI_PYTHON_VERSION" == "3.12" ]]; then - # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. - # Having CMake allows building them. - # Once the wheels are added, this should be removed - b/308399490. - brew install cmake -fi - # TFCI Mac VM images do not have twine installed by default so we need to # install it manually. We use Twine in nightly builds to upload Python packages # to PyPI. diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index ac8c282389920f..9b8c0a6850d6e6 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -28,7 +28,7 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then fi tfrun bazel build $TFCI_BAZEL_COMMON_ARGS //tensorflow/tools/pip_package/v2:wheel $TFCI_BUILD_PIP_PACKAGE_ARGS -tfrun cp -a "./bazel-bin/tensorflow/tools/pip_package/v2/wheel_house/." 
"$TFCI_OUTPUT_DIR" +tfrun find ./bazel-bin/tensorflow/tools/pip_package -iname "*.whl" -exec cp {} $TFCI_OUTPUT_DIR \; tfrun ./ci/official/utilities/rename_and_verify_wheels.sh if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 434b38d603df80..2335d295d0faf6 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - 
--hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - 
--hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - 
--hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - 
--hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + 
--hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + 
--hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + 
--hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + 
--hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ 
--hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - 
--hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - 
--hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - 
--hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + 
--hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + 
--hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,113 +249,115 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - 
--hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - 
--hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - 
--hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - 
--hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + 
--hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + 
--hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - 
--hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + 
--hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -407,10 +400,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -440,57 +429,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - 
--hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + 
--hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - 
--hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -526,8 +494,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -542,13 +510,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -560,81 +532,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - 
--hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - 
--hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - 
--hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - 
--hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + 
--hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + 
--hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + 
--hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 434b38d603df80..2335d295d0faf6 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - 
--hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - 
--hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - 
--hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - 
--hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + 
--hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + 
--hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + 
--hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + 
--hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ 
--hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 
\ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - 
--hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - 
--hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + 
--hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + 
--hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,113 +249,115 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + 
--hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - 
--hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - 
--hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + 
--hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + 
--hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + 
--hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - 
--hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + 
--hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -407,10 +400,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -440,57 +429,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - 
--hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - 
--hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -526,8 +494,8 @@ six==1.16.0 \ # via # 
astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -542,13 +510,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -560,81 +532,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - 
--hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - --hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - 
--hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - --hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - 
--hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - --hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + 
--hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + 
--hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + 
--hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 4697b1849dc273..9bc6eff7313ec3 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - 
--hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - 
--hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - --hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - 
--hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - --hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - 
--hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - --hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - 
--hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - --hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + 
--hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + 
--hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + 
--hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + 
--hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ 
--hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - 
--hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - --hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - 
--hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - --hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - 
--hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + --hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + 
--hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + --hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + 
--hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + --hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,150 +249,156 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in 
-markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ - --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - 
--hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - 
--hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + 
--hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + 
--hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + 
--hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - 
--hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - --hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ 
--hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 # via keras-nightly -numpy==1.26.1 ; python_version >= "3.12" \ - --hash=sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668 \ - --hash=sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9 \ - --hash=sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f \ - --hash=sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5 \ - --hash=sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53 \ - --hash=sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2 \ - --hash=sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974 \ - --hash=sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f \ - --hash=sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42 \ - --hash=sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2 \ - --hash=sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af \ - --hash=sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67 \ - --hash=sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e \ - --hash=sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c \ - --hash=sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7 \ - --hash=sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e \ - --hash=sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908 \ - --hash=sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66 \ - --hash=sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24 \ - --hash=sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b \ - --hash=sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e \ - --hash=sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe \ - 
--hash=sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a \ - --hash=sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575 \ - --hash=sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297 \ - --hash=sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104 \ - --hash=sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab \ - --hash=sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3 \ - --hash=sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244 \ - --hash=sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124 \ - --hash=sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617 \ - --hash=sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c +numpy==1.26.4 ; python_version >= "3.12" \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + 
--hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via # -r requirements.in # h5py @@ -411,10 +408,6 @@ numpy==1.26.1 ; python_version >= 
"3.12" \ # opt-einsum # scipy # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -444,57 +437,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa 
+psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + 
--hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -530,8 +502,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -546,13 +518,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ 
--hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -564,81 +540,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - 
--hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - 
--hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - 
--hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + 
--hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + 
--hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + 
--hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in setuptools==68.2.2 \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 0fb35480a8f886..9d9e85aceda9c7 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -1,6 +1,6 @@ -absl-py==2.0.0 \ - --hash=sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3 \ - --hash=sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via # keras-nightly # tb-nightly @@ -12,105 +12,101 @@ astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ 
--hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 # via -r requirements.in -cachetools==5.3.2 \ - --hash=sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2 \ - --hash=sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1 - # via google-auth -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 # via requests -charset-normalizer==3.3.1 \ - --hash=sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5 \ - --hash=sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93 \ - --hash=sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a \ - --hash=sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d \ - --hash=sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c \ - --hash=sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1 \ - --hash=sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58 \ - --hash=sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2 \ - --hash=sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557 \ - --hash=sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147 \ - --hash=sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041 \ - --hash=sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2 \ - --hash=sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2 \ - --hash=sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7 \ - --hash=sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296 \ - 
--hash=sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690 \ - --hash=sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67 \ - --hash=sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57 \ - --hash=sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597 \ - --hash=sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846 \ - --hash=sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b \ - --hash=sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97 \ - --hash=sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c \ - --hash=sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62 \ - --hash=sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa \ - --hash=sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f \ - --hash=sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e \ - --hash=sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821 \ - --hash=sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3 \ - --hash=sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4 \ - --hash=sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb \ - --hash=sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727 \ - --hash=sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514 \ - --hash=sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d \ - --hash=sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761 \ - --hash=sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55 \ - --hash=sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f \ - --hash=sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c \ - --hash=sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034 \ - 
--hash=sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6 \ - --hash=sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae \ - --hash=sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1 \ - --hash=sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14 \ - --hash=sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1 \ - --hash=sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228 \ - --hash=sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708 \ - --hash=sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48 \ - --hash=sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f \ - --hash=sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5 \ - --hash=sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f \ - --hash=sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4 \ - --hash=sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8 \ - --hash=sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff \ - --hash=sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61 \ - --hash=sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b \ - --hash=sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97 \ - --hash=sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b \ - --hash=sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605 \ - --hash=sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728 \ - --hash=sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d \ - --hash=sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c \ - --hash=sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf \ - --hash=sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673 \ - 
--hash=sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1 \ - --hash=sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b \ - --hash=sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41 \ - --hash=sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8 \ - --hash=sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f \ - --hash=sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4 \ - --hash=sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008 \ - --hash=sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9 \ - --hash=sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5 \ - --hash=sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f \ - --hash=sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e \ - --hash=sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273 \ - --hash=sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45 \ - --hash=sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e \ - --hash=sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656 \ - --hash=sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e \ - --hash=sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c \ - --hash=sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2 \ - --hash=sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72 \ - --hash=sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056 \ - --hash=sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397 \ - --hash=sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42 \ - --hash=sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd \ - --hash=sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3 \ - 
--hash=sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213 \ - --hash=sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf \ - --hash=sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67 +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + 
--hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + 
--hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + 
--hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ @@ -133,6 +129,7 @@ dm-tree==0.1.8 \ 
--hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b \ --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ @@ -140,6 +137,8 @@ dm-tree==0.1.8 \ --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ --hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22 \ + --hash=sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715 \ --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ @@ -147,6 +146,7 @@ dm-tree==0.1.8 \ --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e \ --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ @@ -154,6 +154,7 @@ dm-tree==0.1.8 \ --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ 
--hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly @@ -161,71 +162,61 @@ gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r requirements.in -google-auth==2.23.3 \ - --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ - --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.1.0 \ - --hash=sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12 \ - --hash=sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb - # via tb-nightly -grpcio==1.59.2 \ - --hash=sha256:023088764012411affe7db183d1ada3ad9daf2e23ddc719ff46d7061de661340 \ - --hash=sha256:08d77e682f2bf730a4961eea330e56d2f423c6a9b91ca222e5b1eb24a357b19f \ - --hash=sha256:0a4a3833c0e067f3558538727235cd8a49709bff1003200bbdefa2f09334e4b1 \ - --hash=sha256:0a754aff9e3af63bdc4c75c234b86b9d14e14a28a30c4e324aed1a9b873d755f \ - --hash=sha256:11168ef43e4a43ff1b1a65859f3e0ef1a173e277349e7fb16923ff108160a8cd \ - --hash=sha256:128e20f57c5f27cb0157e73756d1586b83c1b513ebecc83ea0ac37e4b0e4e758 \ - --hash=sha256:1f9524d1d701e399462d2c90ba7c193e49d1711cf429c0d3d97c966856e03d00 \ - --hash=sha256:1ff16d68bf453275466a9a46739061a63584d92f18a0f5b33d19fc97eb69867c \ - --hash=sha256:2067274c88bc6de89c278a672a652b4247d088811ece781a4858b09bdf8448e3 \ - --hash=sha256:2171c39f355ba5b551c5d5928d65aa6c69807fae195b86ef4a7d125bcdb860a9 \ - 
--hash=sha256:242adc47725b9a499ee77c6a2e36688fa6c96484611f33b1be4c57ab075a92dd \ - --hash=sha256:27f879ae604a7fcf371e59fba6f3ff4635a4c2a64768bd83ff0cac503142fef4 \ - --hash=sha256:2b230028a008ae1d0f430acb227d323ff8a619017415cf334c38b457f814119f \ - --hash=sha256:3059668df17627f0e0fa680e9ef8c995c946c792612e9518f5cc1503be14e90b \ - --hash=sha256:31176aa88f36020055ace9adff2405a33c8bdbfa72a9c4980e25d91b2f196873 \ - --hash=sha256:36f53c2b3449c015880e7d55a89c992c357f176327b0d2873cdaaf9628a37c69 \ - --hash=sha256:3b4368b33908f683a363f376dfb747d40af3463a6e5044afee07cf9436addf96 \ - --hash=sha256:3c61d641d4f409c5ae46bfdd89ea42ce5ea233dcf69e74ce9ba32b503c727e29 \ - --hash=sha256:4abb717e320e74959517dc8e84a9f48fbe90e9abe19c248541e9418b1ce60acd \ - --hash=sha256:4c93f4abbb54321ee6471e04a00139c80c754eda51064187963ddf98f5cf36a4 \ - --hash=sha256:535561990e075fa6bd4b16c4c3c1096b9581b7bb35d96fac4650f1181e428268 \ - --hash=sha256:53c9aa5ddd6857c0a1cd0287225a2a25873a8e09727c2e95c4aebb1be83a766a \ - --hash=sha256:5d573e70a6fe77555fb6143c12d3a7d3fa306632a3034b4e7c59ca09721546f8 \ - --hash=sha256:6009386a2df66159f64ac9f20425ae25229b29b9dd0e1d3dd60043f037e2ad7e \ - --hash=sha256:686e975a5d16602dc0982c7c703948d17184bd1397e16c8ee03511ecb8c4cdda \ - --hash=sha256:6959fb07e8351e20501ffb8cc4074c39a0b7ef123e1c850a7f8f3afdc3a3da01 \ - --hash=sha256:6b25ed37c27e652db01be341af93fbcea03d296c024d8a0e680017a268eb85dd \ - --hash=sha256:6da6dea3a1bacf99b3c2187e296db9a83029ed9c38fd4c52b7c9b7326d13c828 \ - --hash=sha256:72ca2399097c0b758198f2ff30f7178d680de8a5cfcf3d9b73a63cf87455532e \ - --hash=sha256:73abb8584b0cf74d37f5ef61c10722adc7275502ab71789a8fe3cb7ef04cf6e2 \ - --hash=sha256:74100fecaec8a535e380cf5f2fb556ff84957d481c13e54051c52e5baac70541 \ - --hash=sha256:75c6ecb70e809cf1504465174343113f51f24bc61e22a80ae1c859f3f7034c6d \ - --hash=sha256:7cf05053242f61ba94014dd3a986e11a083400a32664058f80bf4cf817c0b3a1 \ - --hash=sha256:9411e24328a2302e279e70cae6e479f1fddde79629fcb14e03e6d94b3956eabf \ - 
--hash=sha256:a213acfbf186b9f35803b52e4ca9addb153fc0b67f82a48f961be7000ecf6721 \ - --hash=sha256:bb7e0fe6ad73b7f06d7e2b689c19a71cf5cc48f0c2bf8608469e51ffe0bd2867 \ - --hash=sha256:c2504eed520958a5b77cc99458297cb7906308cb92327f35fb7fbbad4e9b2188 \ - --hash=sha256:c35aa9657f5d5116d23b934568e0956bd50c615127810fffe3ac356a914c176a \ - --hash=sha256:c5f09cffa619adfb44799fa4a81c2a1ad77c887187613fb0a8f201ab38d89ba1 \ - --hash=sha256:c978f864b35f2261e0819f5cd88b9830b04dc51bcf055aac3c601e525a10d2ba \ - --hash=sha256:cbe946b3e6e60a7b4618f091e62a029cb082b109a9d6b53962dd305087c6e4fd \ - --hash=sha256:cc3e4cd087f07758b16bef8f31d88dbb1b5da5671d2f03685ab52dece3d7a16e \ - --hash=sha256:cf0dead5a2c5a3347af2cfec7131d4f2a2e03c934af28989c9078f8241a491fa \ - --hash=sha256:d2794f0e68b3085d99b4f6ff9c089f6fdd02b32b9d3efdfbb55beac1bf22d516 \ - --hash=sha256:d2fa68a96a30dd240be80bbad838a0ac81a61770611ff7952b889485970c4c71 \ - --hash=sha256:d6f70406695e3220f09cd7a2f879333279d91aa4a8a1d34303b56d61a8180137 \ - --hash=sha256:d8f9cd4ad1be90b0cf350a2f04a38a36e44a026cac1e036ac593dc48efe91d52 \ - --hash=sha256:da2d94c15f88cd40d7e67f7919d4f60110d2b9d5b1e08cf354c2be773ab13479 \ - --hash=sha256:e1727c1c0e394096bb9af185c6923e8ea55a5095b8af44f06903bcc0e06800a2 \ - --hash=sha256:e420ced29b5904cdf9ee5545e23f9406189d8acb6750916c2db4793dada065c6 \ - --hash=sha256:e82c5cf1495244adf5252f925ac5932e5fd288b3e5ab6b70bec5593074b7236c \ - --hash=sha256:f1ef0d39bc1feb420caf549b3c657c871cad4ebbcf0580c4d03816b0590de0cf \ - --hash=sha256:f8753a6c88d1d0ba64302309eecf20f70d2770f65ca02d83c2452279085bfcd3 \ - --hash=sha256:f93dbf58f03146164048be5426ffde298b237a5e059144847e4940f5b80172c3 +grpcio==1.60.1 \ + --hash=sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8 \ + --hash=sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8 \ + --hash=sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104 \ + 
--hash=sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc \ + --hash=sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092 \ + --hash=sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1 \ + --hash=sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180 \ + --hash=sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05 \ + --hash=sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287 \ + --hash=sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216 \ + --hash=sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c \ + --hash=sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d \ + --hash=sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594 \ + --hash=sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c \ + --hash=sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87 \ + --hash=sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de \ + --hash=sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9 \ + --hash=sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2 \ + --hash=sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff \ + --hash=sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904 \ + --hash=sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73 \ + --hash=sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549 \ + --hash=sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f \ + --hash=sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c \ + --hash=sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0 \ + --hash=sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e \ + --hash=sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367 \ + 
--hash=sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc \ + --hash=sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303 \ + --hash=sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb \ + --hash=sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c \ + --hash=sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804 \ + --hash=sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03 \ + --hash=sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6 \ + --hash=sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f \ + --hash=sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0 \ + --hash=sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd \ + --hash=sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21 \ + --hash=sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858 \ + --hash=sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce \ + --hash=sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9 \ + --hash=sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23 \ + --hash=sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073 \ + --hash=sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2 \ + --hash=sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872 \ + --hash=sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe \ + --hash=sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7 \ + --hash=sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6 \ + --hash=sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929 \ + --hash=sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962 \ + --hash=sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a \ + 
--hash=sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a \ + --hash=sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177 \ + --hash=sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525 # via # -r requirements.in # tb-nightly @@ -258,117 +249,119 @@ h5py==3.10.0 \ # via # -r requirements.in # keras-nightly -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.6 \ + --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ + --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f # via requests -importlib-metadata==6.8.0 \ - --hash=sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb \ - --hash=sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743 +importlib-metadata==7.0.1 \ + --hash=sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e \ + --hash=sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc # via markdown jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r requirements.in -keras-nightly==3.0.0.dev2023103103 \ - --hash=sha256:25a6a9030d7e067d535ec41ae6700636c90cabada80fbddb3bb8775006807860 \ - --hash=sha256:c39fccebb3d4cc9d838371981c4d3eef08fc06eadf6cffb39b356cb625bab50f +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r requirements.in -lit==17.0.4 \ - --hash=sha256:ee2e180128e770abc6aed3a02de2daf09d81b7d30225e315205d3599c311d304 +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r requirements.in -markdown==3.5 \ - --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ 
- --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 +markdown==3.5.2 \ + --hash=sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd \ + --hash=sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8 # via tb-nightly markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ - --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - 
--hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - 
--hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - --hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ - --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + 
--hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + 
--hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + 
--hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 # via werkzeug mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.3.1 \ - --hash=sha256:3d8ca0acbd377082792d8b97081ba580abdad67c6afb7f827012c675b052f058 \ - --hash=sha256:42a8980afd8b7c8e270e8b5c260237286b5b26acd276fcb758d13cd7cb567e99 \ - --hash=sha256:438437e2e614a3c91d75581653b6c40ec890e8b5994d7190a90c931740151c95 \ - --hash=sha256:4828b62fa3bf1ae35faa40f3db9a38ec72fbce02f328a1d14c3a9da4606af364 \ - --hash=sha256:4d94b2d1bed77284694f7fd0479640fa7aa5d96433dca3cbcec407a5ef752e77 \ - --hash=sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267 \ - --hash=sha256:5727effa7650f7ab10906542d137cfb3244fdc3b2b519beff42f82def8ba59be \ - --hash=sha256:5e0b0b6bb07fa5ad11bb61d174667176bee5e05857225067aabfc5adc1b51d23 \ - --hash=sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611 \ - --hash=sha256:70984b473db6489ec1d8c79b082a1322105155193049d08a3b0c515094e9777b \ - --hash=sha256:979d7d196d9a17e0135ae22878f74241fbd3522cef58d7b292f1fd5b32282201 \ - --hash=sha256:a777928dcba8865ab4a8157eeb25d23aed7bc82e5fd74e1d5eca821d3f148b39 \ - --hash=sha256:cb0c404e0dd3e25b56362c1c1e5de0ef717f727dde59fa721af4ce6ab2acca44 \ - --hash=sha256:d1a8dc3bac1da2a17d0e2e4cba36ee89721d0bd33ea4765af2eefb5f41409e0f \ - 
--hash=sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16 \ - --hash=sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b \ - --hash=sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e - # via jax +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # jax + # keras-nightly namex==0.0.7 \ --hash=sha256:84ba65bc4d22bd909e3d26bf2ffb4b9529b608cb3f9a4336f776b04204ced69b \ --hash=sha256:8a4f062945f405d77cb66b907f16aa2fd83681945e998be840eb6c4154d40108 @@ -411,10 +404,6 @@ numpy==1.23.5 ; python_version <= "3.11" \ # opt-einsum # scipy # tb-nightly 
-oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 @@ -444,57 +433,36 @@ protobuf==4.23.4 \ --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly -psutil==5.9.6 \ - --hash=sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28 \ - --hash=sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017 \ - --hash=sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602 \ - --hash=sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac \ - --hash=sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a \ - --hash=sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9 \ - --hash=sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4 \ - --hash=sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c \ - --hash=sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c \ - --hash=sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c \ - --hash=sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a \ - --hash=sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c \ - --hash=sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57 \ - --hash=sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a \ - --hash=sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d \ - --hash=sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa +psutil==5.9.8 \ + 
--hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -pygments==2.16.1 \ - --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ - --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 +pygments==2.17.2 \ + 
--hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rich==13.6.0 \ - --hash=sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245 \ - --hash=sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef + # via -r requirements.in +rich==13.7.0 \ + --hash=sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa \ + --hash=sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235 # via keras-nightly -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth scipy==1.11.3 \ --hash=sha256:00f325434b6424952fbb636506f0567898dca7b0f7654d48f1c382ea338ce9a3 \ --hash=sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929 \ @@ -530,8 +498,8 @@ six==1.16.0 \ # via # astunparse # tb-nightly -tb-nightly==2.15.0a20231023 \ - --hash=sha256:8990a52985e3296aa18a6825efc017bcded9e2fb2cbcdd2d01f5c62d4fdc9825 +tb-nightly==2.17.0a20240214 \ + --hash=sha256:dbd59bcfd9b028e6199050e36cbb4ee1db4a94e69a7f8c00865f1498112a2b83 # via -r requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -546,13 +514,17 @@ termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ 
--hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a # via -r requirements.in +tf-keras-nightly==2.16.0.dev2024021410 \ + --hash=sha256:05a4c19c6a795ec9ed6f4dac69d443601b9206b910524d8b19bad6e362d49f47 \ + --hash=sha256:83d5a1dd3979b42164a2fac4bdd2ed5981ef0c2a0514ca64e4bc67e369caef5c + # via tb-nightly typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef # via -r requirements.in -urllib3==2.0.7 \ - --hash=sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84 \ - --hash=sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e +urllib3==2.2.0 \ + --hash=sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20 \ + --hash=sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224 # via requests werkzeug==3.0.1 \ --hash=sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc \ @@ -564,81 +536,77 @@ wheel==0.41.3 \ # via # -r requirements.in # astunparse -wrapt==1.14.1 \ - --hash=sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3 \ - --hash=sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b \ - --hash=sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4 \ - --hash=sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2 \ - --hash=sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656 \ - --hash=sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3 \ - --hash=sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9 \ - --hash=sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff \ - --hash=sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9 \ - --hash=sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310 \ - 
--hash=sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224 \ - --hash=sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a \ - --hash=sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57 \ - --hash=sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069 \ - --hash=sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335 \ - --hash=sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383 \ - --hash=sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe \ - --hash=sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204 \ - --hash=sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87 \ - --hash=sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d \ - --hash=sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b \ - --hash=sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907 \ - --hash=sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be \ - --hash=sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f \ - --hash=sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0 \ - --hash=sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28 \ - --hash=sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1 \ - --hash=sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853 \ - --hash=sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc \ - --hash=sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf \ - --hash=sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3 \ - --hash=sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3 \ - --hash=sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164 \ - --hash=sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1 \ - 
--hash=sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c \ - --hash=sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1 \ - --hash=sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7 \ - --hash=sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1 \ - --hash=sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320 \ - --hash=sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed \ - --hash=sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1 \ - --hash=sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248 \ - --hash=sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c \ - --hash=sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456 \ - --hash=sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77 \ - --hash=sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef \ - --hash=sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1 \ - --hash=sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7 \ - --hash=sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86 \ - --hash=sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4 \ - --hash=sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d \ - --hash=sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d \ - --hash=sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8 \ - --hash=sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8 \ - --hash=sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5 \ - --hash=sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a \ - --hash=sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471 \ - --hash=sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00 \ - 
--hash=sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68 \ - --hash=sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3 \ - --hash=sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d \ - --hash=sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735 \ - --hash=sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d \ - --hash=sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569 \ - --hash=sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7 \ - --hash=sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59 \ - --hash=sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5 \ - --hash=sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb \ - --hash=sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b \ - --hash=sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f \ - --hash=sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55 \ - --hash=sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462 \ - --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ - --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + 
--hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + 
--hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + 
--hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 # via -r requirements.in zipp==3.17.0 \ --hash=sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 \ diff --git a/tensorflow/BUILD b/tensorflow/BUILD index acc8468d6168e4..9eb036f01e0614 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -44,7 +44,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") # # buildifier: disable=out-of-order-load # load("//devtools/build_cleaner/skylark:action_config_test.bzl", "action_config_test") # load("//devtools/copybara/rules:copybara.bzl", "copybara_config_test") -# load("//tools/build_defs/license:license.bzl", "license") +# load("@rules_license//rules:license.bzl", "license") # # buildifier: enable=out-of-order-load # copybara:uncomment_end @@ -1631,6 +1631,7 @@ genrule( 
d="$${d#*external/farmhash_archive/src}" d="$${d#*external/$${extname}/}" + d="$${d#_virtual_includes/*/}" fi mkdir -p "$@/$${d}" diff --git a/tensorflow/c/eager/abstract_tensor_handle.cc b/tensorflow/c/eager/abstract_tensor_handle.cc index 8a4438e2b9e75a..e04a9810638f61 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.cc +++ b/tensorflow/c/eager/abstract_tensor_handle.cc @@ -34,7 +34,7 @@ std::string AbstractTensorHandle::DebugString() const { Status AbstractTensorHandle::TensorHandleStatus() const { // Tensor handles in current runtime don't carry error info and this method // should always return OK status. - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 29301e4e37f754..2d79853d7988ec 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -555,7 +555,7 @@ class CAPICustomDeviceTensorHandle } summary = std::string(reinterpret_cast(summary_buffer->data), summary_buffer->length); - return OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 13b688889a4567..3cb7a5d0fa5f1a 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -315,11 +315,11 @@ class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass { tensorflow::Status Run( const tensorflow::GraphOptimizationPassOptions& options) override { if (!enabled_) { - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } if (first_call_) { first_call_ = false; - return ::tensorflow::OkStatus(); + return absl::OkStatus(); } return tensorflow::errors::Internal("Graph pass runs for more than once!"); } @@ -447,7 +447,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { return tensorflow::errors::Internal("Injected graph pass error."); } } - return ::tensorflow::OkStatus(); + 
return absl::OkStatus(); } private: diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index d52e938c047a6d..c2f8125ddbb76a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -606,7 +606,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op, TF_Status* status) { tensorflow::unwrap(op)->SetCancellationManager( tensorflow::unwrap(cancellation_manager)); - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue, @@ -667,7 +667,7 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_ContextGetGraphDebugInfo(TFE_Context* ctx, const char* function_name, @@ -691,7 +691,7 @@ void TFE_ContextGetGraphDebugInfo(TFE_Context* ctx, const char* function_name, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, @@ -817,7 +817,7 @@ void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf, buf->data_deallocator = [](void* data, size_t length) { tensorflow::port::Free(data); }; - status->status = ::tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus, @@ -960,7 +960,7 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, *state_iter = std::move(s); ++state_iter; } - status->status = tensorflow::OkStatus(); + status->status = absl::OkStatus(); } void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id, diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc 
b/tensorflow/c/eager/c_api_unified_experimental.cc index 53f340ee2aa450..8422459c21b529 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -52,7 +52,7 @@ Status SetDefaultTracingEngine(const char* name) { auto entry = GetFactories().find(name); if (entry != GetFactories().end()) { default_factory = GetFactories().find(name)->second; - return OkStatus(); + return absl::OkStatus(); } string msg = absl::StrCat( "No tracing engine factory has been registered with the key '", name, diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 5e804dca267a0d..0c9d4830850bb7 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -71,7 +71,7 @@ class GraphTensor : public TracingTensorHandle { DCHECK_GE(num_dims, -1); TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); if (num_dims == kUnknownRank) { - return OkStatus(); + return absl::OkStatus(); } std::vector dims(num_dims, kUnknownDim); @@ -81,7 +81,7 @@ class GraphTensor : public TracingTensorHandle { TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape)); - return OkStatus(); + return absl::OkStatus(); } tensorflow::FullTypeDef FullType() const override { @@ -119,7 +119,7 @@ class GraphOperation : public TracingOperation { device_name_ = raw_device_name; } op_type_ = op; - return OkStatus(); + return absl::OkStatus(); } Status SetOpName(const char* const op_name) override { if (op_) { @@ -135,7 +135,7 @@ class GraphOperation : public TracingOperation { mutex_lock l(g_->mu); op_.reset(new TF_OperationDescription(g_, op_type_.c_str(), g_->graph.NewName(op_name).c_str())); - return OkStatus(); + return absl::OkStatus(); } const string& Name() const override { return op_type_; } const string& DeviceName() const override { return device_name_; } @@ 
-143,7 +143,7 @@ class GraphOperation : public TracingOperation { Status SetDeviceName(const char* name) override { // TODO(srbs): Implement this. device_name_ = name; - return OkStatus(); + return absl::OkStatus(); } Status AddInput(AbstractTensorHandle* input) override { @@ -153,7 +153,7 @@ class GraphOperation : public TracingOperation { "Unable to cast input to GraphTensor"); } TF_AddInput(op_.get(), t->output_); - return OkStatus(); + return absl::OkStatus(); } Status AddInputList(absl::Span inputs) override { std::vector tf_outputs(inputs.size()); @@ -166,7 +166,7 @@ class GraphOperation : public TracingOperation { tf_outputs[i] = t->output_; } TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size()); - return OkStatus(); + return absl::OkStatus(); } Status Execute(absl::Span retvals, int* num_retvals) override { @@ -182,26 +182,26 @@ class GraphOperation : public TracingOperation { for (int i = 0; i < *num_retvals; ++i) { retvals[i] = new GraphTensor({operation, i}, g_); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrString(const char* attr_name, const char* data, size_t length) override { tensorflow::StringPiece s(data, length); op_->node_builder.Attr(attr_name, s); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrInt(const char* attr_name, int64_t value) override { op_->node_builder.Attr(attr_name, static_cast(value)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFloat(const char* attr_name, float value) override { op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrBool(const char* attr_name, bool value) override { op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrType(const char* const attr_name, DataType value) override { if (!op_) { @@ -210,7 +210,7 @@ class GraphOperation : public TracingOperation { "op_type and op_name must be specified before specifying attrs."); } 
op_->node_builder.Attr(attr_name, value); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) override { @@ -220,7 +220,7 @@ class GraphOperation : public TracingOperation { reinterpret_cast(dims), num_dims)); } op_->node_builder.Attr(attr_name, shape); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFunction(const char* attr_name, const AbstractOperation* value) override { @@ -232,7 +232,7 @@ class GraphOperation : public TracingOperation { tensorflow::NameAttrList func_name; func_name.set_name(string(value, value + length)); op_->node_builder.Attr(attr_name, func_name); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrTensor(const char* attr_name, AbstractTensorInterface* tensor) override { @@ -255,26 +255,26 @@ class GraphOperation : public TracingOperation { } op_->node_builder.Attr(attr_name, v); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFloatList(const char* attr_name, const float* values, int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrIntList(const char* attr_name, const int64_t* values, int num_values) override { op_->node_builder.Attr( attr_name, ArraySlice( reinterpret_cast(values), num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrTypeList(const char* attr_name, const DataType* values, int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrBoolList(const char* attr_name, const unsigned char* values, int num_values) override { @@ -285,7 +285,7 @@ class GraphOperation : public TracingOperation { op_->node_builder.Attr(attr_name, ArraySlice(b.get(), num_values)); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrShapeList(const char* attr_name, const int64_t** dims, const 
int* num_dims, int num_values) override { @@ -300,7 +300,7 @@ class GraphOperation : public TracingOperation { } } op_->node_builder.Attr(attr_name, shapes); - return OkStatus(); + return absl::OkStatus(); } Status SetAttrFunctionList( const char* attr_name, @@ -368,7 +368,7 @@ class GraphContext : public TracingContext { } inputs_.push_back(t->output_); *output = tensorflow::down_cast(outputs[0]); - return OkStatus(); + return absl::OkStatus(); } Status Finalize(OutputList* outputs, AbstractFunction** f) override { @@ -393,7 +393,7 @@ class GraphContext : public TracingContext { TF_DeleteFunction(func); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); - return OkStatus(); + return absl::OkStatus(); } Status RegisterFunction(AbstractFunction* func) override { diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index ca033cd2266b01..f8de31aadbaa6f 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -169,57 +169,57 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, "Only DLPack bools of bitwidth 8 are supported, got: ", dtype.bits); } *tf_dtype = TF_DataType::TF_BOOL; - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLUInt: switch (dtype.bits) { case 8: *tf_dtype = TF_DataType::TF_UINT8; - return OkStatus(); + return absl::OkStatus(); case 16: *tf_dtype = TF_DataType::TF_UINT16; - return OkStatus(); + return absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_UINT32; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_UINT64; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ", dtype.bits); } - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLInt: switch (dtype.bits) { case 8: *tf_dtype = TF_DataType::TF_INT8; - return OkStatus(); + return absl::OkStatus(); case 16: *tf_dtype = TF_DataType::TF_INT16; - return OkStatus(); + return 
absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_INT32; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_INT64; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported Int bits: ", dtype.bits); } - return OkStatus(); + return absl::OkStatus(); case DLDataTypeCode::kDLFloat: switch (dtype.bits) { case 16: *tf_dtype = TF_DataType::TF_HALF; - return OkStatus(); + return absl::OkStatus(); case 32: *tf_dtype = TF_DataType::TF_FLOAT; - return OkStatus(); + return absl::OkStatus(); case 64: *tf_dtype = TF_DataType::TF_DOUBLE; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument("Unsupported Float bits: ", dtype.bits); @@ -229,7 +229,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.bits) { case 16: *tf_dtype = TF_DataType::TF_BFLOAT16; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument( "Unsupported BFloat bits: ", dtype.bits); @@ -239,10 +239,10 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.bits) { case 64: *tf_dtype = TF_DataType::TF_COMPLEX64; - return OkStatus(); + return absl::OkStatus(); case 128: *tf_dtype = TF_DataType::TF_COMPLEX128; - return OkStatus(); + return absl::OkStatus(); default: return tensorflow::errors::InvalidArgument( "Unsupported Complex bits: ", dtype.bits); diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index 2042c857e8b211..2fcaee07b37f50 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -65,7 +65,7 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, // If the output is a scalar, then return the scalar output if (num_dims_out == 0) { outputs[0] = model_out.release(); - return OkStatus(); + return absl::OkStatus(); } // Else, reduce sum the output to get a scalar @@ -85,7 +85,7 @@ Status 
RunAndMaybeSum(AbstractContext* ctx, Model forward, // Reduce sum the output on all dimensions. TF_RETURN_IF_ERROR(ops::Sum(ctx, model_out.get(), sum_dims.get(), &outputs[0], /*keep_dims=*/false, "sum_output")); - return OkStatus(); + return absl::OkStatus(); } // ========================= End Helper Functions============================== @@ -198,7 +198,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, TF_RETURN_IF_ERROR(TestTensorHandleWithDims( ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); TF_DeleteTensor(theta_tensor); - return OkStatus(); + return absl::OkStatus(); } } // namespace gradients diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 6f4ab3016beb63..326a9e8cb829d4 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -47,7 +47,7 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, TF_RETURN_IF_ERROR( op->Execute(absl::Span(outputs), &num_outputs)); *result = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -59,7 +59,7 @@ Status GradientRegistry::Register( return errors::AlreadyExists(error_msg); } registry_.insert({op_name, gradient_function_factory}); - return OkStatus(); + return absl::OkStatus(); } Status GradientRegistry::Lookup( const ForwardOperation& op, @@ -70,7 +70,7 @@ Status GradientRegistry::Lookup( return errors::NotFound(error_msg); } gradient_function->reset(iter->second(op)); - return OkStatus(); + return absl::OkStatus(); } TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) { @@ -200,7 +200,7 @@ Status TapeVSpace::BuildOnesLike(const TapeTensor& t, TF_RETURN_IF_ERROR( op->Execute(absl::Span(outputs), &num_outputs)); *result = outputs[0]; - return OkStatus(); + return absl::OkStatus(); } // Looks up the ID of a Gradient. 
@@ -292,7 +292,7 @@ Status Tape::ComputeGradient( TF_RETURN_IF_ERROR(GradientTape::ComputeGradient( vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets, output_gradients, result, /*build_default_zeros_grads*/ false)); - return OkStatus(); + return absl::OkStatus(); } // Helper functions which delegate to `AbstractOperation`, update @@ -309,7 +309,7 @@ Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, ForwardOperation* forward_op_) { TF_RETURN_IF_ERROR(op_->AddInput(input)); forward_op_->inputs.push_back(input); - return OkStatus(); + return absl::OkStatus(); } Status AddInputList(AbstractOperation* op_, absl::Span inputs, @@ -318,7 +318,7 @@ Status AddInputList(AbstractOperation* op_, for (auto input : inputs) { forward_op_->inputs.push_back(input); } - return OkStatus(); + return absl::OkStatus(); } Status SetAttrString(AbstractOperation* op_, const char* attr_name, @@ -482,7 +482,7 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn)); tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(), op_->Name()); - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index a345240e8c3e4f..9df16f10290d0b 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -59,7 +59,7 @@ class CppGradients Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics")); - return OkStatus(); + return absl::OkStatus(); } TEST_P(CppGradients, TestSetAttrString) { diff --git a/tensorflow/c/eager/graph_function.cc b/tensorflow/c/eager/graph_function.cc index 3f4430bb614ea1..bf45feb34afb0f 100644 --- a/tensorflow/c/eager/graph_function.cc +++ b/tensorflow/c/eager/graph_function.cc @@ -22,7 +22,7 @@ GraphFunction::GraphFunction(FunctionDef fdef) 
GraphFunction::~GraphFunction() {} Status GraphFunction::GetFunctionDef(FunctionDef** fdef) { *fdef = &fdef_; - return OkStatus(); + return absl::OkStatus(); } } // namespace graph } // namespace tracing diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.cc b/tensorflow/c/eager/immediate_execution_tensor_handle.cc index d8cb9e165495c1..c99a270f0cb804 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.cc +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.cc @@ -55,7 +55,7 @@ Status ImmediateExecutionTensorHandle::SummarizeValue( return status; } summary = resolved->SummarizeValue(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 0522ad3b73072f..e5b1ee97a2e802 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -604,7 +604,7 @@ Status ParallelTensor::Shape(const std::vector** shape) const { shape_ = std::vector(dim_sizes.begin(), dim_sizes.end()); } *shape = &*shape_; - return OkStatus(); + return absl::OkStatus(); } Status ParallelTensor::SummarizeValue(std::string& summary) { @@ -624,7 +624,7 @@ Status ParallelTensor::SummarizeValue(std::string& summary) { "\": ", component_summary); } summary += "}"; - return OkStatus(); + return absl::OkStatus(); } } // namespace parallel_device diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 48bdcb2c9a26bf..c0b62760cd4207 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -80,6 +80,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor", ], ) diff --git 
a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 12391143a4d9e0..e20cfcfd83a205 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -573,11 +573,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return std::unique_ptr( new CEvent(&device_, stream_executor_)); } - std::unique_ptr CreateKernelImplementation() - override { - LOG(FATAL) - << "CreateKernelImplementation is not supported by pluggable device."; - } std::unique_ptr GetStreamImplementation() override { return std::unique_ptr( diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 0f3e2e76aa4ebe..3542586ffa8e4e 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace { @@ -200,11 +201,11 @@ TEST_F(StreamExecutorTest, HostMemoryAllocate) { }; StreamExecutor* executor = GetExecutor(0); ASSERT_FALSE(allocate_called); - void* mem = executor->HostMemoryAllocate(8); - ASSERT_NE(mem, nullptr); + TF_ASSERT_OK_AND_ASSIGN(auto mem, executor->HostMemoryAllocate(8)); + ASSERT_NE(mem->opaque(), nullptr); ASSERT_TRUE(allocate_called); ASSERT_FALSE(deallocate_called); - executor->HostMemoryDeallocate(mem); + mem.reset(); ASSERT_TRUE(deallocate_called); } diff --git a/tensorflow/cc/framework/fuzzing/BUILD b/tensorflow/cc/framework/fuzzing/BUILD index 772bff55a9251c..ec424fc0425630 100644 --- a/tensorflow/cc/framework/fuzzing/BUILD +++ b/tensorflow/cc/framework/fuzzing/BUILD @@ -29,7 +29,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:hash", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", ], diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc index 416bb56e820359..cacc15ca32d28f 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc @@ -26,14 +26,19 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "tensorflow/cc/framework/cc_op_gen_util.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/hash.h" #include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" namespace tensorflow { diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index b76be0fd608715..8b2570fdc953a7 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -19,15 +19,17 @@ limitations under the License. 
#include #include -#include "absl/status/status.h" #include "tensorflow/cc/framework/cc_op_gen_util.h" #include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/status.h" diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 935b37b37aa5c0..4d8d5dfa11da87 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -109,7 +109,7 @@ cc_library( "//tensorflow/core:direct_session", "//tensorflow/core:all_kernels", ] + if_google( - ["@local_tsl//tsl/platform/default/build_config:tensorflow_platform_specific"], + ["//tensorflow/core/platform/default/build_config:tensorflow_platform_specific"], [], )) + if_not_mobile([ "//tensorflow/core:core_cpu", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a245bf59a1f187..ae63fdab2fa32c 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -104,7 +104,7 @@ static Status ValidateNode(const NodeDef& node) { "Saved model contains node \"", node.name(), "\" which is a constant tensor but no value has been provided")); } - return OkStatus(); + return absl::OkStatus(); } static Status ValidateFunctionNotRecursive(const FunctionDef& function) { @@ -117,7 +117,7 @@ static Status ValidateFunctionNotRecursive(const FunctionDef& function) { } } - return OkStatus(); + return absl::OkStatus(); } static Status ValidateSavedTensors(const GraphDef& graph_def) { @@ -137,7 
+137,7 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) { } } - return OkStatus(); + return absl::OkStatus(); } Tensor CreateStringTensor(const string& value) { @@ -223,7 +223,7 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir, return RunOnce(run_options, inputs, {}, {init_op_name}, nullptr /* outputs */, &run_metadata, session); } - return OkStatus(); + return absl::OkStatus(); } Status RunRestore(const RunOptions& run_options, const string& export_dir, @@ -247,7 +247,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " "were restored. File does not exist: " << variables_index_path; - return OkStatus(); + return absl::OkStatus(); } const string variables_path = io::JoinPath(variables_directory, kSavedModelVariablesFilename); @@ -293,7 +293,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, session_options, bundle->meta_graph_def, &bundle->session)); TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def, export_dir, &bundle->session)); - return OkStatus(); + return absl::OkStatus(); } Status LoadSavedModel(const SessionOptions& session_options, @@ -469,7 +469,7 @@ Status RestoreSession(const RunOptions& run_options, // Record wall time spent in init op. 
load_latency_by_stage->GetCell(export_dir, "init_graph") ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); - return OkStatus(); + return absl::OkStatus(); } Status LoadSavedModel(const SessionOptions& session_options, @@ -494,7 +494,7 @@ Status LoadSavedModel(const SessionOptions& session_options, *bundle = SavedModelBundleLite( std::make_unique(std::move(legacy_bundle.session)), std::move(*legacy_bundle.meta_graph_def.mutable_signature_def())); - return OkStatus(); + return absl::OkStatus(); } bool MaybeSavedModelDirectory(const string& export_dir) { diff --git a/tensorflow/cc/saved_model/loader_util.cc b/tensorflow/cc/saved_model/loader_util.cc index e17f2ed4abb690..3a984bf31b3cd9 100644 --- a/tensorflow/cc/saved_model/loader_util.cc +++ b/tensorflow/cc/saved_model/loader_util.cc @@ -42,7 +42,7 @@ Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, kSavedModelInitOpSignatureKey); } *init_op_name = sig_def_outputs_it->second.name(); - return OkStatus(); + return absl::OkStatus(); } const auto& collection_def_map = meta_graph_def.collection_def(); @@ -62,7 +62,7 @@ Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, } *init_op_name = init_op_it->second.node_list().value(0); } - return OkStatus(); + return absl::OkStatus(); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, @@ -73,13 +73,13 @@ Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, for (const auto& asset : meta_graph_def.asset_file_def()) { asset_file_defs->push_back(asset); } - return OkStatus(); + return absl::OkStatus(); } // Fall back to read from collection to be backward compatible with v1. 
const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { - return OkStatus(); + return absl::OkStatus(); } const auto& any_assets = assets_it->second.any_list().value(); for (const auto& any_asset : any_assets) { @@ -88,7 +88,7 @@ Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")); asset_file_defs->push_back(asset_file_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace internal diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index 5563439f290391..b90f84438b3abf 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -63,7 +63,7 @@ Status FindMetaGraphDef(const std::unordered_set& tags, if (!port::kLittleEndian) { TF_RETURN_IF_ERROR(ByteSwapTensorContentInMetaGraphDef(meta_graph_def)); } - return OkStatus(); + return absl::OkStatus(); } } return Status( @@ -137,7 +137,7 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir, TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); TF_RETURN_IF_ERROR( FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def)); - return OkStatus(); + return absl::OkStatus(); } Status ReadSavedModelDebugInfoIfPresent( @@ -156,7 +156,7 @@ Status ReadSavedModelDebugInfoIfPresent( ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); *debug_info_proto = std::make_unique(std::move(debug_info)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/util.cc b/tensorflow/cc/saved_model/util.cc index 3e0b1eb27026bb..b474f1ef3ed0f3 100644 --- a/tensorflow/cc/saved_model/util.cc +++ b/tensorflow/cc/saved_model/util.cc @@ -86,7 +86,7 @@ Status GetInputValues( absl::StrJoin(seen_request_inputs, ","), ", request input: ", 
absl::StrJoin(GetMapKeys(request_inputs), ","))); } - return OkStatus(); + return absl::OkStatus(); } } // namespace saved_model diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 8eb99e2b46fa09..ee7f9e4805a1aa 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "if_google", "if_oss", "tf_cc_binary", "tf_cc_test") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +18,7 @@ cc_library( deps = if_oss([ "//tensorflow/core:test_main", ]) + if_google([ - "@local_tsl//tsl/platform/default/build_config:test_main", + "//tensorflow/core/platform/default/build_config:test_main", ]), ) diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index fcd230f4ab296a..f7273c091a4d37 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -153,7 +153,7 @@ static void CompareWithGoldenFile( // To update the golden file, flip update_golden to true and run the // following: - // bazel test --test_strategy=local \ + // blaz test --test_strategy=local \ // "third_party/tensorflow/compiler/aot:codegen_test" const bool update_golden = false; string golden_file_name = @@ -230,7 +230,7 @@ TEST(CodegenTest, Golden) { /*result_param_number=*/1), BufferInfo::MakeResultParameter(/*size=*/5 * 4, /*result_param_number=*/2)}, - 0, {})); + 0, nullptr, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/codegen_test_h.golden 
b/tensorflow/compiler/aot/codegen_test_h.golden index 1131d95fdb18ad..71b234a8385806 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -462,7 +462,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const ::xla::ProgramShapeProto* StaticProgramShape() { static const ::xla::ProgramShapeProto* kShape = []() { ::xla::ProgramShapeProto* proto = new ::xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 133); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 157); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index 55178e3ee7130a..28910275330f8c 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a180ee12e22e9b..9e3fd27b8f6e86 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1814,9 +1814,6 @@ tf_cc_test( tf_cuda_cc_test( name = "device_compiler_test", srcs = ["device_compiler_test.cc"], - env = { - "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", - }, tags = [ "config-cuda-only", "no_oss", # This test only runs with GPU. diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index 05c3f35a0042b8..fc562d8277a77e 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -38,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/local_client.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/mutex.h" @@ -317,6 +318,7 @@ DeviceCompiler::CompileStrict( cache_value.compilation_status = loaded_executable->status(); if (loaded_executable->ok()) { out_executable = *std::move(*loaded_executable); + metrics::UpdatePersistentCacheLoadCount(); } } else { auto built_executable = diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 82ed25767b90de..f85fd5fde4c1fa 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -282,6 +282,7 @@ void AllocateAndParseFlags() { bool enable_mlir_bridge_is_explicit = false; bool enable_mlir_merge_control_flow_pass = true; bool enable_mlir_convert_control_to_data_outputs_pass = false; + bool enable_mlir_composite_tpuexecute_side_effects = false; bool enable_mlir_strict_clusters = false; bool enable_mlir_multiple_local_cpu_devices = false; // Dump graphs in TFG dialect. 
@@ -376,6 +377,10 @@ void AllocateAndParseFlags() { &enable_mlir_convert_control_to_data_outputs_pass, "Enables `tf-executor-convert-control-to-data-outputs` pass for " "MLIR-Based TensorFlow Compiler Bridge."), + Flag("tf_mlir_composite_tpuexecute_side_effects", + &enable_mlir_composite_tpuexecute_side_effects, + "Enables certain TPUExecute ops to run in parallel if they only " + "operate on resources that live on composite devices."), Flag("tf_mlir_enable_strict_clusters", &enable_mlir_strict_clusters, "Do not allow clusters that have cyclic control dependencies."), Flag("tf_mlir_enable_multiple_local_cpu_devices", @@ -414,6 +419,8 @@ void AllocateAndParseFlags() { enable_mlir_merge_control_flow_pass; mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass = enable_mlir_convert_control_to_data_outputs_pass; + mlir_flags->tf_mlir_enable_composite_tpuexecute_side_effects = + enable_mlir_composite_tpuexecute_side_effects; mlir_flags->tf_mlir_enable_strict_clusters = enable_mlir_strict_clusters; mlir_flags->tf_mlir_enable_generic_outside_compilation = enable_mlir_generic_outside_compilation; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 45a4c83a614afd..d2c078a617b258 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -288,6 +288,7 @@ struct MlirCommonFlags { bool tf_mlir_enable_merge_control_flow_pass; bool tf_mlir_enable_convert_control_to_data_outputs_pass; + bool tf_mlir_enable_composite_tpuexecute_side_effects; bool tf_mlir_enable_strict_clusters; bool tf_mlir_enable_generic_outside_compilation; bool tf_mlir_enable_tpu_variable_runtime_reformatting_pass; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 435b63d8f5dbe9..6ba7afe9884e91 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -212,7 +212,7 @@ Status GetTaskName(const std::string_view device_name, std::string* 
task_name) { device_name); } - return OkStatus(); + return absl::OkStatus(); } // Provide SendDeviceMemoryFunction for XLA host callbacks. This callback @@ -400,7 +400,7 @@ Status CompileToLocalExecutable( rm->default_container(), "device_compilation_profiler", &profiler, [](DeviceCompilationProfiler** profiler) { *profiler = new DeviceCompilationProfiler(); - return OkStatus(); + return absl::OkStatus(); })); // Hold the reference to the XLA device compiler and profiler during // evaluation. (We could probably free them sooner because the ResourceMgr @@ -899,7 +899,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { closure.client(), closure.executable(), ctx)); } - OP_REQUIRES_OK(ctx, OkStatus()); + OP_REQUIRES_OK(ctx, absl::OkStatus()); return; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 526059c22fde8b..bf5a0ce68dd4a4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -637,7 +637,7 @@ Status IgnoreResourceOpForSafetyAnalysis( if (n.assigned_device_name().empty()) { *ignore = false; - return OkStatus(); + return absl::OkStatus(); } TF_ASSIGN_OR_RETURN( @@ -649,7 +649,7 @@ Status IgnoreResourceOpForSafetyAnalysis( } else { *ignore = registration->cluster_resource_variable_ops_unsafely; } - return OkStatus(); + return absl::OkStatus(); } StatusOr MarkForCompilationPassImpl::Initialize() { @@ -892,7 +892,7 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { })); TF_RET_CHECK(!changed); - return OkStatus(); + return absl::OkStatus(); } Status MarkForCompilationPassImpl::DeclusterNodes() { @@ -922,7 +922,7 @@ Status MarkForCompilationPassImpl::DeclusterNodes() { } } - return OkStatus(); + return absl::OkStatus(); } // Tracks monotonic sequence numbers for graphs. 
@@ -1010,7 +1010,7 @@ Status MarkForCompilationPassImpl::CreateClusters() { } } - return OkStatus(); + return absl::OkStatus(); } Status MarkForCompilationPassImpl::DumpDebugInfo() { @@ -1022,7 +1022,7 @@ Status MarkForCompilationPassImpl::DumpDebugInfo() { VLogClusteringSummary(); - return OkStatus(); + return absl::OkStatus(); } StatusOr @@ -1181,7 +1181,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { cluster_for_node_[node->id()].Get() = new_cluster; } - return OkStatus(); + return absl::OkStatus(); } StatusOr IsIdentityDrivingConstsInLoop(Node* node) { @@ -1475,7 +1475,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { VLOG(2) << "compilation_candidates_.size() = " << compilation_candidates_.size(); - return OkStatus(); + return absl::OkStatus(); } bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( @@ -1596,7 +1596,7 @@ Status MarkForCompilationPassImpl::Run() { if (!initialized) { // Initialization exited early which means this instance of // MarkForCompilationPassImpl is not set up to run the subsequent phases. - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); @@ -1604,7 +1604,7 @@ Status MarkForCompilationPassImpl::Run() { TF_RETURN_IF_ERROR(CreateClusters()); TF_RETURN_IF_ERROR(DumpDebugInfo()); - return OkStatus(); + return absl::OkStatus(); } void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { @@ -1864,14 +1864,14 @@ Status MarkForCompilation( for (Node* n : graph->nodes()) { // See explanation on `kXlaAlreadyClustered`. if (n->attrs().Find(kXlaAlreadyClustered)) { - return OkStatus(); + return absl::OkStatus(); } // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops // in the graph, which indicates the graph is produced by TPU TF-XLA bridge // and doesn't require auto clustering. 
if (n->type_string() == "TPUExecute" || n->type_string() == "TPUExecuteAndUpdateVariables") { - return OkStatus(); + return absl::OkStatus(); } } @@ -2277,6 +2277,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "VariableShape", "Where", "While", + "XlaAllReduce", "XlaBroadcastHelper", "XlaCallModule", "XlaConcatND", @@ -2298,6 +2299,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaRecv", "XlaReduce", "XlaReducePrecision", + "XlaReduceScatter", "XlaReduceWindow", "XlaRemoveDynamicDimensionSize", "XlaReplicaId", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b9527ae9ec56b0..aabedf61202d3f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -390,7 +390,7 @@ static Status GradForUnaryCwise(FunctionDef* g, {}, // Nodes nodes); - return OkStatus(); + return absl::OkStatus(); } // A gradient containing only supported operators @@ -1816,7 +1816,7 @@ TEST(XlaCompilationTest, DeterministicClusterNames) { " rhs: ", rhs_cluster_name); } - return OkStatus(); + return absl::OkStatus(); }; testing::ResetClusterSequenceNumber(); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index cf0e03bc9f64e5..7c370e46dec63f 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -55,7 +55,7 @@ REGISTER_OP("XlaClusterOutput") for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); }) .Doc( "Operator that connects the output of an XLA computation to other " @@ -112,7 +112,7 @@ REGISTER_OP("_XlaMerge") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"(XLA Merge Op. For use by the XLA JIT only. 
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 715cc1b31738b5..eb66a8d905cc8c 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -113,7 +113,7 @@ Status FindNodesToDecluster(const Graph& graph, } } } - return OkStatus(); + return absl::OkStatus(); } Status PartiallyDeclusterNode(Graph* graph, Node* n) { @@ -156,7 +156,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveNode(n); } - return OkStatus(); + return absl::OkStatus(); } // Clones nodes to outside their cluster to avoid device-to-host copies. For @@ -221,7 +221,7 @@ Status PartiallyDeclusterGraph(Graph* graph) { FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); CHECK(nodes_to_partially_decluster.empty()); - return OkStatus(); + return absl::OkStatus(); } } // namespace reduce_device_to_host_copies @@ -251,12 +251,12 @@ Status MustCompileNode(const Node* n, bool* must_compile) { if (IsMustCompileDevice(device_type)) { *must_compile = true; - return OkStatus(); + return absl::OkStatus(); } // We must compile `n` if it does not have a TensorFlow kernel. 
*must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok(); - return OkStatus(); + return absl::OkStatus(); } // Declusters nodes to reduce the number of times we think we need to recompile @@ -363,7 +363,7 @@ Status PartiallyDeclusterGraph(Graph* graph, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace reduce_recompilation @@ -397,7 +397,7 @@ Status PartiallyDeclusterGraph(Graph* graph) { << " because it is a root shape consumer"; RemoveFromXlaCluster(n); } - return OkStatus(); + return absl::OkStatus(); } } // namespace decluster_root_shape_consumers } // namespace @@ -430,6 +430,6 @@ Status PartiallyDeclusterPass::Run( TF_RETURN_IF_ERROR( decluster_root_shape_consumers::PartiallyDeclusterGraph(graph)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index cd8d231af19f5b..a403557f9a9087 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -96,7 +96,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, profiler::TraceMe traceme("PjRtDeviceContext::CopyDeviceTensorToCPU"); if (device_tensor->NumElements() == 0) { VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } auto literal = std::make_unique(); @@ -149,7 +149,7 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, profiler::TraceMe traceme("PjRtDeviceContext::CopyCPUTensorToDevice"); if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; - done(OkStatus()); + done(absl::OkStatus()); return; } @@ -243,7 +243,7 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, profiler::TraceMe traceme("PjRtDevice_DeviceToDeviceCopy"); if (input->NumElements() == 0) { VLOG(2) << "PjRtDevice_DeviceToDeviceCopy empty tensor"; - done(OkStatus()); + 
done(absl::OkStatus()); return; } diff --git a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc index 9cf49b0666ddd5..ffbcef3371ae81 100644 --- a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc +++ b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc @@ -118,7 +118,7 @@ TEST(RearrangeFunctionArgumentForFunctionTest, Basic) { &fld, &new_fbody)); *fbody = new_fbody.get(); fbodies.push_back(std::move(new_fbody)); - return OkStatus(); + return absl::OkStatus(); }, g.get(), &fld)); @@ -229,7 +229,7 @@ TEST(RearrangeFunctionArgumentForFunctionTest, &fld, &new_fbody)); *fbody = new_fbody.get(); fbodies.push_back(std::move(new_fbody)); - return OkStatus(); + return absl::OkStatus(); }, g.get(), &fld); EXPECT_EQ(status.code(), error::UNIMPLEMENTED); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 650369863c4d05..92f79dde874217 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -106,13 +106,13 @@ Status XlaResourceOpKindForNode( } if (should_ignore) { *out_resource_op_kind = std::nullopt; - return OkStatus(); + return absl::OkStatus(); } const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); if (op_info) { *out_resource_op_kind = op_info->kind(); - return OkStatus(); + return absl::OkStatus(); } // We conservatively assume that functions will both read and write resource @@ -124,7 +124,7 @@ Status XlaResourceOpKindForNode( *out_resource_op_kind = std::nullopt; } - return OkStatus(); + return absl::OkStatus(); } // Returns true if a control or data dependence from a TensorFlow operation of @@ -314,6 +314,6 @@ Status ComputeIncompatibleResourceOperationPairs( std::sort(result->begin(), result->end()); CHECK(std::unique(result->begin(), result->end()) == 
result->end()); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index b9cbd1e3105b13..9a6bb729149dbd 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -33,7 +33,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, const shape_inference::ShapeHandle& handle, PartialTensorShape* shape) { // The default is already unknown - if (!context->RankKnown(handle)) return OkStatus(); + if (!context->RankKnown(handle)) return absl::OkStatus(); std::vector dims(context->Rank(handle)); for (int32_t i = 0, end = dims.size(); i < end; ++i) { @@ -199,7 +199,7 @@ Status PropagateShapes(Graph* graph, } } } - return OkStatus(); + return absl::OkStatus(); } // Store the shapes of the output tensors in a map @@ -235,7 +235,7 @@ Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, << output.handle_shape.DebugString(); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc index f3dd0c7ec78453..9290861d48f0bc 100644 --- a/tensorflow/compiler/jit/shape_inference_helpers.cc +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -41,7 +41,7 @@ Status BackEdgeHelper::Remove(Graph* graph) { for (const BackEdge& be : back_edges_) { graph_->RemoveEdge(be.edge); } - return OkStatus(); + return absl::OkStatus(); } const std::vector& BackEdgeHelper::RemovedEdges() @@ -60,7 +60,7 @@ Status BackEdgeHelper::Replace() { for (const BackEdge& be : back_edges_) { graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index 41afd63cca3b1e..f073902bc03d4a 100644 --- 
a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -58,7 +58,7 @@ Status ShapeAnnotationsMatch( return errors::InvalidArgument("Missing shapes for nodes: ", absl::StrJoin(missing, ",")); } - return OkStatus(); + return absl::OkStatus(); } void DeviceSetup::AddDevicesAndSetUp( diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index 2d382cbda5f7f0..e9880013bf2611 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -88,9 +88,6 @@ tf_cc_test( srcs = [ "device_compiler_serialize_test.cc", ], - env = { - "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", - }, tags = [ "config-cuda-only", "no_oss", # This test only runs with GPU. @@ -110,9 +107,6 @@ tf_cc_test( srcs = [ "device_compiler_serialize_options_test.cc", ], - env = { - "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", - }, tags = [ "config-cuda-only", "no_oss", # This test only runs with GPU. diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index e93af0df217e84..8833ce69d4bf6b 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -95,7 +95,7 @@ Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { } } - return OkStatus(); + return absl::OkStatus(); } Status ReadTextProtoFromString(Env* env, const string& data, @@ -103,7 +103,7 @@ Status ReadTextProtoFromString(Env* env, const string& data, if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) { return errors::DataLoss("Can't parse input data as text proto"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -121,7 +121,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl( LOG(INFO) << "Not running " << ::testing::UnitTest::GetInstance()->current_test_info()->name() << " since test was not built with --config=cuda"; 
- return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(AssertGraphDefIsUnclustered(graphdef)); @@ -158,7 +158,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl( EXPECT_EQ(golden_file_contents, clustering_summary); - return OkStatus(); + return absl::OkStatus(); } Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( @@ -221,7 +221,7 @@ Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, std::move(graph_def_copy), &result)); } - return OkStatus(); + return absl::OkStatus(); } #endif // PLATFORM_GOOGLE diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index c1ae143ee94508..4eb93e85819651 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -131,7 +131,7 @@ Status DeviceCompilerSerializeTest::ExecuteWithBatch(const GraphDef& graph, EXPECT_NEAR(golden_output_tensors[0].flat()(i), output_tensors[0].flat()(i), 1e-3); } - return OkStatus(); + return absl::OkStatus(); } Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( @@ -160,7 +160,7 @@ Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( return errors::NotFound( "Did not find any persistent XLA compilation cache entries to alter."); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h index e8ae70928d17d4..9cf36d0cbc6cb1 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -33,17 +33,17 @@ class JitCompilationListener : public XlaActivityListener { public: Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { - return OkStatus(); + return absl::OkStatus(); } Status Listen( 
const XlaJitCompilationActivity& jit_compilation_activity) override { activity_history_.push_back(jit_compilation_activity); - return OkStatus(); + return absl::OkStatus(); } Status Listen(const XlaOptimizationRemark& optimization_remark) override { - return OkStatus(); + return absl::OkStatus(); } ~JitCompilationListener() override = default; @@ -55,7 +55,7 @@ class JitCompilationListener : public XlaActivityListener { return absl::FailedPreconditionError("Unexpected listener history."); } } - return OkStatus(); + return absl::OkStatus(); } std::vector GetListenerHistory() { diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 6f71814bbe7b6a..c47096796d7f86 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -501,38 +501,6 @@ Status XlaDevice::Sync() { return OkStatus(); } -// TODO(b/112409994): This is no longer necessary. Consolidate it with the -// synchronous version. -void XlaDevice::Sync(const DoneCallback& done) { - VLOG(1) << "XlaDevice::Sync (asynchronous)"; - std::shared_ptr stream; - { - mutex_lock lock(mu_); - stream = stream_; - } - if (!stream) { - done(OkStatus()); - return; - } - - // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at - // the end of the stream, after everything that has already been enqueued - // there at this moment. When the host callback is called, everything before - // it must have already finished, and the host callback will then place the - // task below onto a background thread. (See the implementation of - // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done - // callback is finally called from that background thread, we know for sure - // that everything enqueued onto the stream (i.e., the device) at this very - // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. - // This achieves a device-wide sync. 
- stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) { - profiler::TraceMe activity("XlaDevice::Sync::Callback", - profiler::TraceMeLevel::kInfo); - done(stream->ok() ? OkStatus() - : errors::Internal("XlaDevice::Sync() failed.")); - }); -} - Status XlaDevice::MakeTensorFromProto(DeviceContext* device_context, const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index aeff4501af480d..ecfdd073253a80 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -158,7 +158,6 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; - void Sync(const DoneCallback& done) override; Status TryGetDeviceContext(DeviceContext** out_context) override TF_LOCKS_EXCLUDED(mu_); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 288344eb341d23..dff81fdd367fcc 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -673,7 +673,6 @@ Status PreparePjRtExecutableArguments( } } else { if (av_tensor->GetBuffer() == nullptr) { - // TODO(b/260799971): verify size 0 argument is supported. 
CHECK_EQ(tensor->NumElements(), 0); // Crash OK continue; } diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 742f97d902522a..b30f08a1bfe1b4 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -74,7 +74,6 @@ cc_library( "@local_xla//xla/mlir/framework/transforms:passes", "@local_xla//xla/mlir_hlo:all_passes", "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla", ], ) @@ -204,7 +203,6 @@ cc_library( "@local_xla//xla/mlir/framework/ir:xla_framework", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/service/cpu:hlo_xla_runtime_pipeline", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla", "@stablehlo//:register", ], ) @@ -239,7 +237,6 @@ tf_cc_binary( "@llvm-project//mlir:TranslateLib", "@local_xla//xla/translate/hlo_to_mhlo:translate_registration", "@local_xla//xla/translate/mhlo_to_hlo:translate_registration", - "@local_xla//xla/translate/mhlo_to_lhlo_with_xla:translate_registration", ], ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 3ebbd3b1b81942..38610dcef42b52 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_po load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -452,14 +452,15 @@ cc_library( "utils/constant_utils.h", ], deps = [ - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", 
"@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) @@ -754,6 +755,7 @@ cc_library( "transforms/passes.h", ], deps = [ + ":constant_utils", ":convert_type", ":tensorflow_lite", ":tensorflow_lite_passes_inc_gen", @@ -1267,24 +1269,29 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:statusor", "@local_xla//xla/translate/hlo_to_mhlo:translate", "@stablehlo//:stablehlo_ops", ], @@ -1338,23 +1345,17 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - 
"//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", "//tensorflow/core:core_cpu_base", - "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:all_passes", "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//stablehlo/experimental:experimental_stablehlo_passes", ], ) @@ -1370,45 +1371,54 @@ cc_library( ":tensorflow_lite", ":tf_tfl_passes", "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/metrics:error_collector", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/stablehlo:quantization", "//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass", "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", - "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantize_passes", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", 
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:status", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/remat:metadata_util", "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", + "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:reduced_precision_support", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@flatbuffers//:runtime_cc", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h 
b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 9dd57f2bdea429..39f81c7a6a770d 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -95,6 +95,10 @@ struct PassConfig { // ops and to convert kernels to quantized kernels wherever appropriate. quant::QDQConversionMode qdq_conversion_mode = quant::QDQConversionMode::kQDQNone; + + // When set to true, StableHLO Quantizer is run. The full configuration for + // the quantizer is at `TocoFlags::quantization_config`. + bool enable_stablehlo_quantizer = false; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc index 1bcb86de9a4a94..127d485b842f94 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc @@ -53,11 +53,7 @@ limitations under the License. #include "tsl/platform/status.h" namespace tensorflow { -namespace { - -using ::testing::HasSubstr; -using ::testing::IsEmpty; -using ::testing::Not; +namespace debug_test { class NopPass : public mlir::PassWrapper> { public: @@ -84,6 +80,15 @@ class AlwaysFailPass void runOnOperation() override { signalPassFailure(); } }; +} // namespace debug_test + +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using namespace tensorflow::debug_test; + class InitPassManagerTest : public testing::Test { protected: InitPassManagerTest() @@ -179,8 +184,7 @@ TEST_F(InitPassManagerTest, DumpToDir) { TF_ASSERT_OK(tsl::ReadFileToString( tsl::Env::Default(), tsl::io::JoinPath( - dump_dir, - "00000000.main.tensorflow_anonymous_namespace_NopPass_after.mlir"), + dump_dir, "00000000.main.tensorflow_debug_test_NopPass_after.mlir"), &mlir_dump)); EXPECT_THAT(mlir_dump, Not(IsEmpty())); } @@ -190,7 +194,7 @@ TEST_F(InitPassManagerTest, DumpToDir) { tsl::Env::Default(), 
tsl::io::JoinPath( dump_dir, - "00000000.main.tensorflow_anonymous_namespace_NopPass_before.mlir"), + "00000000.main.tensorflow_debug_test_NopPass_before.mlir"), &mlir_dump)); EXPECT_THAT(mlir_dump, Not(IsEmpty())); } @@ -207,12 +211,10 @@ TEST_F(InitPassManagerTest, PrintIRBeforeEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); - EXPECT_THAT( - captured_out, - HasSubstr("IR Dump Before tensorflow::(anonymous namespace)::NopPass")); EXPECT_THAT(captured_out, - Not(HasSubstr( - "IR Dump After tensorflow::(anonymous namespace)::NopPass"))); + HasSubstr("IR Dump Before tensorflow::debug_test::NopPass")); + EXPECT_THAT(captured_out, + Not(HasSubstr("IR Dump After tensorflow::debug_test::NopPass"))); } TEST_F(InitPassManagerTest, PrintIRAfterEverything) { @@ -226,13 +228,11 @@ TEST_F(InitPassManagerTest, PrintIRAfterEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump After tensorflow::debug_test::MutatePass")); EXPECT_THAT( captured_out, - HasSubstr("IR Dump After tensorflow::(anonymous namespace)::MutatePass")); - EXPECT_THAT( - captured_out, - Not(HasSubstr( - "IR Dump Before tensorflow::(anonymous namespace)::MutatePass"))); + Not(HasSubstr("IR Dump Before tensorflow::debug_test::MutatePass"))); } TEST_F(InitPassManagerTest, PrintIRBeforeAndAfterEverything) { @@ -247,13 +247,10 @@ TEST_F(InitPassManagerTest, PrintIRBeforeAndAfterEverything) { pm.addPass(std::make_unique()); ASSERT_TRUE(mlir::succeeded(pm.run(*module_))); - EXPECT_THAT( - captured_out, - HasSubstr("IR Dump After tensorflow::(anonymous namespace)::MutatePass")); - EXPECT_THAT( - captured_out, - HasSubstr( - "IR Dump Before tensorflow::(anonymous namespace)::MutatePass")); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump After tensorflow::debug_test::MutatePass")); + EXPECT_THAT(captured_out, + HasSubstr("IR Dump Before tensorflow::debug_test::MutatePass")); } 
TEST_F(InitPassManagerTest, ElideLargeElementAttrs) { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 04a92f3412ba82..25f62ce0981b6f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -2632,8 +2632,7 @@ Translator::CreateMetadataVector() { } else { module_.emitError( "all values in tfl.metadata's dictionary key-value pairs should " - "be " - "string attributes"); + "be string attributes"); return std::nullopt; } } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 26ae90d95a97ea..bd912797d44820 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -40,6 +40,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -1495,6 +1496,8 @@ OwningOpRef tflite::FlatBufferToMlir( bool use_stablehlo_constant = false; + llvm::SmallVector metadata_attrs; + mlir::StringSet<> seen_attr; for (const auto& metadata : model->metadata) { if (metadata->name == tflite::kModelControlDependenciesMetadataKey) { const std::vector& data = model->buffers[metadata->buffer]->data; @@ -1502,15 +1505,28 @@ OwningOpRef tflite::FlatBufferToMlir( reinterpret_cast(data.data()), data.size(), &model_control_dependencies)) { return emitError(base_loc, - "Invalid model_control_dependencies metadata"), + "invalid model_control_dependencies metadata"), nullptr; } - break; + continue; } + + // Skip already seen attributes. Ideally there should be no duplicates here. 
+ if (!seen_attr.try_emplace(metadata->name).second) continue; + // check if the model is serialized using stablehlo constant tensor if (metadata->name == tflite::kModelUseStablehloTensorKey) { use_stablehlo_constant = true; + metadata_attrs.emplace_back(builder.getStringAttr(metadata->name), + builder.getStringAttr("true")); + continue; } + + std::vector buffer = model->buffers[metadata->buffer]->data; + metadata_attrs.emplace_back( + builder.getStringAttr(metadata->name), + builder.getStringAttr(llvm::StringRef( + reinterpret_cast(buffer.data()), buffer.size()))); } std::vector func_names; @@ -1528,18 +1544,15 @@ OwningOpRef tflite::FlatBufferToMlir( builder.getStringAttr(model->description)); } + if (!metadata_attrs.empty()) { + module->setAttr("tfl.metadata", builder.getDictionaryAttr(metadata_attrs)); + } + if (!model->signature_defs.empty()) { module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(builder.getContext())); } - if (use_stablehlo_constant) { - module->setAttr("tfl.metadata", - builder.getDictionaryAttr(builder.getNamedAttr( - tflite::kModelUseStablehloTensorKey, - builder.getStringAttr("true")))); - } - absl::flat_hash_map subgraph_to_signature_map; for (int i = 0; i < model->signature_defs.size(); i++) { diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 665d101a4f19af..db138d74b76048 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -285,14 +285,6 @@ static mlir::Attribute BuildRankedTensorAttr(std::vector shape, return mlir::DenseIntElementsAttr::get(ty, value); } -static mlir::Attribute BuildI64ArrayAttr(std::vector shape, - std::vector value, - mlir::Builder builder) { - // Expand splats. BuildI64ArrayAttr assumes shape.size() == 1. 
- if (value.size() == 1) value.resize(shape[0], value[0]); - return builder.getDenseI64ArrayAttr(value); -} - static mlir::Attribute BuildF32ArrayAttr(std::vector value, mlir::Builder builder) { std::vector typecast(value.begin(), value.end()); @@ -399,29 +391,22 @@ void BuiltinOptions2ToAttributesManual( if (const auto* op = op_union.AsStablehloBroadcastInDimOptions()) { attributes.emplace_back(builder.getNamedAttr( "broadcast_dimensions", - BuildRankedTensorAttr( - {static_cast(op->broadcast_dimensions.size())}, - op->broadcast_dimensions, builder))); + builder.getDenseI64ArrayAttr(op->broadcast_dimensions))); return; } if (const auto* op = op_union.AsStablehloSliceOptions()) { - std::vector shape = { - static_cast(op->start_indices.size())}; attributes.emplace_back(builder.getNamedAttr( - "start_indices", BuildI64ArrayAttr(shape, op->start_indices, builder))); + "start_indices", builder.getDenseI64ArrayAttr(op->start_indices))); attributes.emplace_back(builder.getNamedAttr( - "limit_indices", BuildI64ArrayAttr(shape, op->limit_indices, builder))); + "limit_indices", builder.getDenseI64ArrayAttr(op->limit_indices))); attributes.emplace_back(builder.getNamedAttr( - "strides", BuildI64ArrayAttr(shape, op->strides, builder))); + "strides", builder.getDenseI64ArrayAttr(op->strides))); return; } if (const auto* op = op_union.AsStablehloConvolutionOptions()) { if (!(op->window_strides.empty())) { - std::vector shape; - shape.push_back(static_cast(op->window_strides.size())); attributes.emplace_back(builder.getNamedAttr( - "window_strides", - BuildRankedTensorAttr(shape, op->window_strides, builder))); + "window_strides", builder.getDenseI64ArrayAttr(op->window_strides))); } if (!(op->padding.empty())) { std::vector shape; @@ -431,25 +416,19 @@ void BuiltinOptions2ToAttributesManual( "padding", BuildRankedTensorAttr(shape, op->padding, builder))); } if (!(op->lhs_dilation.empty())) { - std::vector shape; - shape.push_back(static_cast(op->lhs_dilation.size())); 
attributes.emplace_back(builder.getNamedAttr( - "lhs_dilation", - BuildRankedTensorAttr(shape, op->lhs_dilation, builder))); + "lhs_dilation", builder.getDenseI64ArrayAttr(op->lhs_dilation))); } if (!(op->rhs_dilation.empty())) { - std::vector shape; - shape.push_back(static_cast(op->rhs_dilation.size())); attributes.emplace_back(builder.getNamedAttr( - "rhs_dilation", - BuildRankedTensorAttr(shape, op->rhs_dilation, builder))); + "rhs_dilation", builder.getDenseI64ArrayAttr(op->rhs_dilation))); } - if (!(op->window_reversal.empty())) + if (!(op->window_reversal.empty())) { + llvm::SmallVector window_reversal; + for (bool b : op->window_reversal) window_reversal.push_back(b); attributes.emplace_back(builder.getNamedAttr( - "window_reversal", - BuildRankedTensorAttr( - {static_cast(op->window_reversal.size())}, - op->window_reversal, builder))); + "window_reversal", builder.getDenseBoolArrayAttr(window_reversal))); + } attributes.emplace_back(builder.getNamedAttr( "dimension_numbers", mlir::stablehlo::ConvDimensionNumbersAttr::get( @@ -502,20 +481,18 @@ void BuiltinOptions2ToAttributesManual( static_cast(op->edge_padding_low.size())}; attributes.emplace_back(builder.getNamedAttr( "edge_padding_low", - BuildI64ArrayAttr(shape, op->edge_padding_low, builder))); + builder.getDenseI64ArrayAttr(op->edge_padding_low))); attributes.emplace_back(builder.getNamedAttr( "edge_padding_high", - BuildI64ArrayAttr(shape, op->edge_padding_high, builder))); + builder.getDenseI64ArrayAttr(op->edge_padding_high))); attributes.emplace_back(builder.getNamedAttr( "interior_padding", - BuildI64ArrayAttr(shape, op->interior_padding, builder))); + builder.getDenseI64ArrayAttr(op->interior_padding))); return; } if (const auto* op = op_union.AsStablehloDynamicSliceOptions()) { attributes.emplace_back(builder.getNamedAttr( - "slice_sizes", - BuildI64ArrayAttr({static_cast(op->slice_sizes.size())}, - op->slice_sizes, builder))); + "slice_sizes", builder.getDenseI64ArrayAttr(op->slice_sizes))); 
return; } if (const auto* op = op_union.AsStablehloCompareOptions()) { @@ -540,39 +517,27 @@ void BuiltinOptions2ToAttributesManual( } if (const auto* op = op_union.AsStablehloReduceOptions()) { attributes.emplace_back(builder.getNamedAttr( - "dimensions", - BuildRankedTensorAttr({static_cast(op->dimensions.size())}, - op->dimensions, builder))); + "dimensions", builder.getDenseI64ArrayAttr(op->dimensions))); return; } if (const auto* op = op_union.AsStablehloReduceWindowOptions()) { if (!op->window_dimensions.empty()) { attributes.emplace_back(builder.getNamedAttr( "window_dimensions", - BuildRankedTensorAttr( - {static_cast(op->window_dimensions.size())}, - op->window_dimensions, builder))); + builder.getDenseI64ArrayAttr(op->window_dimensions))); } if (!op->window_strides.empty()) { attributes.emplace_back(builder.getNamedAttr( - "window_strides", - BuildRankedTensorAttr( - {static_cast(op->window_strides.size())}, - op->window_strides, builder))); + "window_strides", builder.getDenseI64ArrayAttr(op->window_strides))); } if (!op->base_dilations.empty()) { attributes.emplace_back(builder.getNamedAttr( - "base_dilations", - BuildRankedTensorAttr( - {static_cast(op->base_dilations.size())}, - op->base_dilations, builder))); + "base_dilations", builder.getDenseI64ArrayAttr(op->base_dilations))); } if (!op->window_dilations.empty()) { attributes.emplace_back(builder.getNamedAttr( "window_dilations", - BuildRankedTensorAttr( - {static_cast(op->window_dilations.size())}, - op->window_dilations, builder))); + builder.getDenseI64ArrayAttr(op->window_dilations))); } if (!op->padding.empty()) { attributes.emplace_back(builder.getNamedAttr( @@ -617,9 +582,7 @@ void BuiltinOptions2ToAttributesManual( builder.getNamedAttr("dimension_numbers", gather_dim)); if (!op->slice_sizes.empty()) { attributes.emplace_back(builder.getNamedAttr( - "slice_sizes", - BuildRankedTensorAttr({static_cast(op->slice_sizes.size())}, - op->slice_sizes, builder))); + "slice_sizes", 
builder.getDenseI64ArrayAttr(op->slice_sizes))); } attributes.emplace_back(builder.getNamedAttr( "indices_are_sorted", BuildBoolAttr(op->indices_are_sorted, builder))); @@ -628,9 +591,7 @@ void BuiltinOptions2ToAttributesManual( if (const auto* op = op_union.AsStablehloTransposeOptions()) { if (!op->permutation.empty()) { attributes.emplace_back(builder.getNamedAttr( - "permutation", - BuildI64ArrayAttr({static_cast(op->permutation.size())}, - op->permutation, builder))); + "permutation", builder.getDenseI64ArrayAttr(op->permutation))); } return; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 45a1e3b25e1335..e8f9787947d6eb 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1061,7 +1061,7 @@ def TFL_DepthwiseConv2DOp : def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ Pure, AccumulatorUniformScale<2, 0, 1>, AffineQuantizedOpInterface, - AffineOpCoefficient<-1, 1>, + AffineOpCoefficient<0, 1>, TFL_SparseOp, DeclareOpInterfaceMethods, QuantizableResult, @@ -1097,7 +1097,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let extraClassDeclaration = [{ // AffineQuantizedOpInterface: int GetChannelDimIndex() { return 0; } - int GetQuantizationDimIndex() { return -1; } + int GetQuantizationDimIndex() { return 0; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } std::vector> GetFloatBlockSize() { return {{1, 4}}; } @@ -1219,12 +1219,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, TFL_Str]>:$params, + TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, QI8, TFL_Str]>:$params, TFL_TensorOf<[I16, I32, I64]>:$indices ); let results = (outs - TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, TFL_Str]>:$output + TFL_TensorOf<[F32, I1, I8, I16, I64, I32, UI8, QI8, TFL_Str]>:$output ); } @@ -5399,12 +5399,12 @@ subsequent operation and then be optimized away, 
however.) }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex>]>:$input, + TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, UI32, QUI8, I16, QI16, I64, Complex>]>:$input, TFL_I32OrI64Tensor:$shape ); let results = (outs - TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex>]>:$output + TFL_TensorOf<[F32, I32, I1, TFL_I4, I8, QI8, UI8, UI32, QUI8, I16, QI16, I64, Complex>]>:$output ); let hasCanonicalizer = 1; diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index a911d438d20368..0ac8bc0ff65117 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -23,21 +23,25 @@ cc_library( srcs = ["tf_tfl_flatbuffer_helpers.cc"], hdrs = ["tf_tfl_flatbuffer_helpers.h"], deps = [ + "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/tools/optimize:reduced_precision_support", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", 
"@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -56,10 +60,8 @@ cc_library( deps = [ ":tf_tfl_flatbuffer_helpers", "//tensorflow/compiler/mlir/lite:common", - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", - "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", @@ -68,12 +70,8 @@ cc_library( "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/tools/optimize:reduced_precision_support", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", + "@com_google_absl//absl/status", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", ], ) @@ -90,6 +88,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -114,23 +113,19 @@ cc_library( ":tf_tfl_flatbuffer_helpers", "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", - "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", 
"//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:types_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", "@local_xla//xla/service:hlo_parser", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 9409856cd4c864..f678daf32f234c 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -16,49 +16,41 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" #include -#include #include #include #include -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "absl/status/status.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include 
"tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { -Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - const GraphDebugInfo& debug_info, - const GraphDef& input, - string* result) { + +absl::Status ConvertGraphDefToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const GraphDebugInfo& debug_info, const GraphDef& input, + std::string* result) { using ::tflite::optimize::ReducedPrecisionSupport; mlir::MLIRContext context; GraphImportConfig specs; mlir::quant::QuantizationSpecs quant_specs; // Parse input arrays. - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; std::vector> node_mins; std::vector> node_maxs; @@ -68,21 +60,20 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes, &node_shapes, &node_mins, &node_maxs)); - TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo( - node_names, node_dtypes, node_shapes, &specs.inputs)); + TF_RETURN_IF_ERROR( + ParseInputArrayInfo(node_names, node_dtypes, node_shapes, &specs.inputs)); // Parse output arrays. 
- std::vector output_arrays(model_flags.output_arrays().begin(), - model_flags.output_arrays().end()); - TF_RETURN_IF_ERROR( - tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs)); + std::vector output_arrays(model_flags.output_arrays().begin(), + model_flags.output_arrays().end()); + TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); // Parse control output arrays. - std::vector control_output_arrays( + std::vector control_output_arrays( model_flags.control_output_arrays().begin(), model_flags.control_output_arrays().end()); - TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(control_output_arrays, - &specs.control_outputs)); + TF_RETURN_IF_ERROR( + ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs)); specs.prune_unused_nodes = true; specs.convert_legacy_fed_inputs = true; @@ -118,10 +109,12 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, toco_flags.guarantee_all_funcs_one_use(); pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); + // StableHLO Quantizer is not supported for GraphDef inputs, so + // quantization_py_function_lib is set to nullptr. return internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, - /*saved_model_tags=*/{}, result, - /*session=*/std::nullopt); + /*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr, + /*quantization_py_function_lib=*/nullptr); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h index e69d3c718d9b37..54f8a996e8883c 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h @@ -15,9 +15,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -26,10 +28,10 @@ namespace tensorflow { // Converts the given GraphDef to a TF Lite FlatBuffer string according to the // given model flags, toco flags and debug information. Returns error status if // it fails to convert the input. -Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - const GraphDebugInfo& debug_info, - const GraphDef& input, string* result); +absl::Status ConvertGraphDefToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + const GraphDebugInfo& debug_info, const GraphDef& input, + std::string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index f81b5e8b5da6a7..b25040827ccbae 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -20,27 +20,23 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_join.h" -#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" @@ -49,23 +45,24 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace tensorflow { namespace { // Error collector that simply ignores errors reported. -class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector { +class NoOpErrorCollector : public protobuf::io::ErrorCollector { public: - void AddError(int line, int column, const string& message) override {} + void AddError(int line, int column, const std::string& message) override {} }; bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { - tensorflow::protobuf::TextFormat::Parser parser; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. 
+ tsl::protobuf::TextFormat::Parser parser; NoOpErrorCollector collector; parser.RecordErrorsTo(&collector); return hlo_proto->ParseFromString(contents) || @@ -75,10 +72,10 @@ bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) { } mlir::OwningOpRef HloToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, + mlir::StringRef input, mlir::MLIRContext* context, bool import_all_computations) { xla::HloProto hlo_proto; - string content(input.data(), input.size()); + std::string content(input.data(), input.size()); if (!LoadHloProto(content, &hlo_proto)) { LOG(ERROR) << "Failed to load proto"; return nullptr; @@ -100,7 +97,7 @@ mlir::OwningOpRef HloTextToMlirHloTranslateFunction( llvm::StringRef input, mlir::MLIRContext* context, bool import_all_computations) { xla::HloProto hlo_proto; - string content(input.data(), input.size()); + std::string content(input.data(), input.size()); auto hlo_module_error = xla::ParseAndReturnUnverifiedModule(content); if (!hlo_module_error.ok()) { @@ -122,16 +119,16 @@ mlir::OwningOpRef HloTextToMlirHloTranslateFunction( } } // namespace -Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, - const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - string* result) { +absl::Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, + const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, + std::string* result) { mlir::MLIRContext context; mlir::quant::QuantizationSpecs quant_specs; // Parse input arrays. - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; std::vector> node_mins; std::vector> node_maxs; @@ -191,10 +188,12 @@ Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, // phase. 
main_func->setAttr("tf.entry_function", builder.getDictionaryAttr(attrs)); + // StableHLO Quantizer is not supported for JAX input models, so + // quantization_py_function_lib is set to nullptr. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, - /*saved_model_tags=*/{}, result, - /*session=*/std::nullopt); + /*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr, + /*quantization_py_function_lib=*/nullptr); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 3bd7b947e61021..57550c9f5b0f9d 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" @@ -48,6 +49,8 @@ limitations under the License. 
namespace tensorflow { +using tensorflow::quantization::PyFunctionLibrary; + Status HandleInputOutputArraysWithModule( const toco::ModelFlags& model_flags, mlir::OwningOpRef* module) { @@ -124,9 +127,10 @@ Status HandleInputOutputArraysWithModule( return OkStatus(); } -Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - string* result) { +Status ConvertSavedModelToTFLiteFlatBuffer( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + std::string* result, + const PyFunctionLibrary* quantization_py_function_lib) { mlir::MLIRContext context; mlir::quant::QuantizationSpecs quant_specs; @@ -199,15 +203,19 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); pass_config.legalize_custom_tensor_list_ops = toco_flags.legalize_custom_tensor_list_ops(); + pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); if (toco_flags.qdq_conversion_mode() == "STATIC") { - pass_config.qdq_conversion_mode = + pass_config.quant_specs.qdq_conversion_mode = mlir::quant::QDQConversionMode::kQDQStatic; } else if (toco_flags.qdq_conversion_mode() == "DYNAMIC") { - pass_config.qdq_conversion_mode = + pass_config.quant_specs.qdq_conversion_mode = mlir::quant::QDQConversionMode::kQDQDynamic; + // Need to set this or else the ops will still use floating point kernels + pass_config.quant_specs.inference_type = tensorflow::DT_QINT8; } else if (toco_flags.qdq_conversion_mode() == "NONE") { - pass_config.qdq_conversion_mode = mlir::quant::QDQConversionMode::kQDQNone; + pass_config.quant_specs.qdq_conversion_mode = + mlir::quant::QDQConversionMode::kQDQNone; } else { return errors::InvalidArgument("Unknown QDQ conversion mode: ", toco_flags.qdq_conversion_mode()); @@ -225,7 +233,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, // TODO(b/153507667): 
Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, tags, result, - bundle ? bundle->GetSession() : nullptr); + bundle.get(), quantization_py_function_lib); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index 362e9e39ae54c8..50d61dbd4f873b 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +29,8 @@ namespace tensorflow { // status if it fails to convert the input. Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - string* result); + string* result, + const quantization::PyFunctionLibrary* quantization_py_function_lib); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 980d74d6a47aa5..795485328d4779 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -15,48 +15,55 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" #include -#include #include #include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/Support/ToolOutputFile.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include 
"tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" -using tsl::StatusOr; - namespace tensorflow { namespace internal { namespace { using ::mlir::quant::ReducedPrecisionSupport; +using ::tensorflow::quantization::PyFunctionLibrary; // Op def string for TFLite_Detection_PostProcess Op. -const char kDetectionPostProcessOp[] = +constexpr mlir::StringRef kDetectionPostProcessOp = "name: 'TFLite_Detection_PostProcess' input_arg: { name: " "'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: " "'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: " @@ -74,7 +81,7 @@ const char kDetectionPostProcessOp[] = "'detections_per_class' type: 'int' default_value { i : 100 }} attr { " "name: 'use_regular_nms' type: 'bool' default_value { b : false }}"; -const char kUnidirectionalSequenceLstmOp[] = +constexpr mlir::StringRef kUnidirectionalSequenceLstmOp = "name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: " "DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } " "input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { " @@ -98,7 +105,7 @@ const char kUnidirectionalSequenceLstmOp[] = "'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} " "attr : { name: '_tflite_input_indices' type: 'list(int)'}"; -const char kUnidirectionalSequenceRnnOp[] = +constexpr mlir::StringRef kUnidirectionalSequenceRnnOp = "name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: " "DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } " "input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { " @@ -158,8 +165,9 @@ DataType 
ConvertIODataTypeToDataType(toco::IODataType dtype) { } } -StatusOr> InputStatsToMinMax(double mean, double std, - DataType type) { +absl::StatusOr> InputStatsToMinMax(double mean, + double std, + DataType type) { // Only qint8 and quint8 are considered here. double qmin, qmax; if (type == DT_QUINT8) { @@ -169,58 +177,59 @@ StatusOr> InputStatsToMinMax(double mean, double std, qmin = -128.0; qmax = 127.0; } else { - return errors::InvalidArgument("Only int8 and uint8 are considered."); + return absl::InvalidArgumentError("Only int8 and uint8 are considered."); } return std::make_pair((qmin - mean) / std, (qmax - mean) / std); } -Status RegisterCustomBuiltinOps(const std::vector extra_tf_opdefs) { +absl::Status RegisterCustomBuiltinOps( + const std::vector extra_tf_opdefs) { for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { return errors::InvalidArgument("fail to parse extra OpDef"); } // Make sure the op is not already registered. If registered continue. 
const OpRegistrationData* op_reg = - tensorflow::OpRegistry::Global()->LookUp(opdef.name()); + OpRegistry::Global()->LookUp(opdef.name()); if (op_reg) continue; - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return OkStatus(); + OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); }); } - return OkStatus(); + return absl::OkStatus(); } } // namespace -Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { +absl::Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { // Register any custom OpDefs. - std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), - toco_flags.custom_opdefs().end()); - extra_tf_opdefs.push_back(kDetectionPostProcessOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp); - extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp); + std::vector extra_tf_opdefs(toco_flags.custom_opdefs().begin(), + toco_flags.custom_opdefs().end()); + extra_tf_opdefs.push_back(kDetectionPostProcessOp.str()); + extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp.str()); + extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp.str()); return RegisterCustomBuiltinOps(extra_tf_opdefs); } -Status PopulateQuantizationSpecs( +absl::Status PopulateQuantizationSpecs( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, mlir::quant::QuantizationSpecs* quant_specs, - std::vector* node_names, std::vector* node_dtypes, + std::vector* node_names, std::vector* node_dtypes, std::vector>>* node_shapes, std::vector>* node_mins, std::vector>* node_maxs) { quant_specs->inference_input_type = ConvertIODataTypeToDataType(toco_flags.inference_input_type()); - tensorflow::DataType inference_type = + DataType inference_type = ConvertIODataTypeToDataType(toco_flags.inference_type()); // Use 
non-float flag `inference_input_type` to override the `inference_type` // because we have to apply quantization to satisfy that. - if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) { + if (quant_specs->inference_input_type != DT_FLOAT) { inference_type = quant_specs->inference_input_type; } @@ -270,11 +279,11 @@ Status PopulateQuantizationSpecs( quant_specs->disable_per_channel = toco_flags.disable_per_channel_quantization(); if (toco_flags.quantize_to_float16()) { - quant_specs->inference_type = tensorflow::DT_HALF; - quant_specs->inference_input_type = tensorflow::DT_HALF; + quant_specs->inference_type = DT_HALF; + quant_specs->inference_input_type = DT_HALF; } else { - quant_specs->inference_type = tensorflow::DT_QINT8; - quant_specs->inference_input_type = tensorflow::DT_QINT8; + quant_specs->inference_type = DT_QINT8; + quant_specs->inference_input_type = DT_QINT8; } } else { // These flags are incompatible with post_training_quantize() as only @@ -313,11 +322,14 @@ Status PopulateQuantizationSpecs( toco_flags.enable_mlir_dynamic_range_quantizer(); quant_specs->enable_mlir_variable_quantization = toco_flags.enable_mlir_variable_quantization(); - return OkStatus(); + quant_specs->disable_per_channel_for_dense_layers = + toco_flags.disable_per_channel_quantization_for_dense_layers(); + return absl::OkStatus(); } // Dumps the op graph of the `module` to `filename` in DOT format. 
-Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { +absl::Status DumpOpGraphToFile(mlir::ModuleOp module, + const std::string& filename) { std::string error_message; auto output = mlir::openOutputFile(filename, &error_message); if (!error_message.empty()) { @@ -329,15 +341,16 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { return errors::Unknown("Failed to dump Op Graph from MLIR module."); } output->keep(); - return OkStatus(); + return absl::OkStatus(); } -Status ConvertMLIRToTFLiteFlatBuffer( +absl::Status ConvertMLIRToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, - const std::unordered_set& saved_model_tags, string* result, - std::optional session) { + const std::unordered_set& saved_model_tags, + std::string* result, SavedModelBundle* saved_model_bundle, + const PyFunctionLibrary* quantization_py_function_lib) { if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( module.get(), @@ -361,7 +374,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy, - saved_model_tags, model_flags.saved_model_dir(), session, result); + saved_model_tags, model_flags.saved_model_dir(), saved_model_bundle, + result, /*serialize_stablehlo_ops=*/false, quantization_py_function_lib); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. 
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 8d7eeb2912a7b6..039e56672ddadc 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -24,8 +24,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -55,7 +57,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, string* result, - std::optional session); + SavedModelBundle* saved_model_bundle, + const quantization::PyFunctionLibrary* quantization_py_function_lib); // Give a warning for any unused flags that have been specified. 
void WarningUnusedFlags(const toco::ModelFlags& model_flags, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 356801e7fd38a0..cf437c27e2ec4e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -2,7 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], default_visibility = [ ":friends", "//tensorflow:__pkg__", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 4545aa412686b7..f7ffc7d71d02f9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -58,7 +58,8 @@ TfLiteStatus QuantizeModel( bool whole_model_verify, bool legacy_float_scale, const absl::flat_hash_set& denylisted_ops, const absl::flat_hash_set& denylisted_nodes, - const bool enable_variable_quantization) { + const bool enable_variable_quantization, + bool disable_per_channel_for_dense_layers) { // Translate TFLite names to mlir op names. 
absl::flat_hash_set denylisted_mlir_op_names; for (const auto& entry : denylisted_ops) { @@ -84,6 +85,8 @@ TfLiteStatus QuantizeModel( quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; quant_specs.disable_per_channel = disable_per_channel; + quant_specs.disable_per_channel_for_dense_layers = + disable_per_channel_for_dense_layers; quant_specs.verify_numeric = verify_numeric; quant_specs.whole_model_verify = whole_model_verify; quant_specs.legacy_float_scale = legacy_float_scale; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index d85aba47811675..50b397ba0206d2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -54,7 +54,8 @@ TfLiteStatus QuantizeModel( bool whole_model_verify = false, bool legacy_float_scale = true, const absl::flat_hash_set& denylisted_ops = {}, const absl::flat_hash_set& denylisted_nodes = {}, - bool enable_variable_quantization = false); + bool enable_variable_quantization = false, + bool disable_per_channel_for_dense_layers = false); } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 5898c9e54234a5..696a2545d7097a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -72,7 +72,8 @@ TfLiteStatus QuantizeModel( const TensorType& activations_type, ErrorReporter* error_reporter, std::string& output_buffer, const bool disable_per_channel = false, const absl::flat_hash_set& blocked_ops = {}, - const absl::flat_hash_set& blocked_nodes = {}) { + const absl::flat_hash_set& blocked_nodes = {}, + const bool 
disable_per_channel_for_dense_layers = false) { TensorType inference_tensor_type = activations_type; const bool fully_quantize = !allow_float; @@ -87,7 +88,10 @@ TfLiteStatus QuantizeModel( input_buffer, input_type, output_type, inference_tensor_type, /*operator_names=*/{}, disable_per_channel, fully_quantize, output_buffer, error_reporter, /*verify_numeric=*/false, /*whole_model_verify=*/false, - /*legacy_float_scale=*/true, blocked_ops, blocked_nodes); + /*legacy_float_scale=*/true, blocked_ops, blocked_nodes, + /*enable_variable_quantization=*/false, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_for_dense_layers); if (status != kTfLiteOk) { return status; } @@ -140,6 +144,21 @@ TfLiteStatus QuantizeModelAllOperators( output_buffer); } +TfLiteStatus QuantizeModelAllOperators( + ModelT* model, const TensorType& input_type, const TensorType& output_type, + bool allow_float, const TensorType& activations_type, + ErrorReporter* error_reporter, std::string& output_buffer, + bool disable_per_channel_for_dense_layers) { + return QuantizeModel(model, input_type, output_type, allow_float, + /*operator_names=*/{}, activations_type, error_reporter, + output_buffer, + /*disable_per_channel=*/false, + /* blocked_ops=*/{}, + /*blocked_nodes=*/{}, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_for_dense_layers); +} + std::unique_ptr ReadModel(const string& model_name) { auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name); return FlatBufferModel::BuildFromFile(model_path.c_str()); @@ -1118,16 +1137,20 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ExpectSameModels(model_, expected_model); } -class QuantizeFCTest : public QuantizeModelTest { +class QuantizeFCTest : public QuantizeModelTest, + public testing::WithParamInterface { protected: QuantizeFCTest() { + disable_per_channel_quantization_for_dense_ = GetParam(); input_model_ = ReadModel(internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = 
UnPackFlatBufferModel(*readonly_model_); } + + bool disable_per_channel_quantization_for_dense_; }; -TEST_F(QuantizeFCTest, VerifyFC8x8) { +TEST_P(QuantizeFCTest, VerifyFC8x8) { auto status = QuantizeModelAllOperators( &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, TensorType_INT8, &error_reporter_, output_buffer_); @@ -1180,7 +1203,7 @@ TEST_F(QuantizeFCTest, VerifyFC8x8) { /*bit_num=*/8, /*symmetric=*/false); } -TEST_F(QuantizeFCTest, VerifyFCFor16x8) { +TEST_P(QuantizeFCTest, VerifyFCFor16x8) { auto status = QuantizeModelAllOperators( &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, TensorType_INT16, &error_reporter_, output_buffer_); @@ -1195,7 +1218,7 @@ TEST_F(QuantizeFCTest, VerifyFCFor16x8) { ASSERT_THAT(op->outputs, SizeIs(1)); const SubGraph* float_graph = readonly_model_->subgraphs()->Get(0); - // Verify FC input tesnor and weight are int16 and int8 quantized. + // Verify FC input tensor and weight are int16 and int8 quantized. const Operator* float_op = float_graph->operators()->Get(0); ASSERT_THAT(float_graph->tensors()->Get(float_op->inputs()->Get(0))->type(), Eq(TensorType_FLOAT32)); @@ -1235,6 +1258,136 @@ TEST_F(QuantizeFCTest, VerifyFCFor16x8) { /*bit_num=*/16, /*symmetric=*/true); } +TEST_P(QuantizeFCTest, VerifyDisablePerChannelQuantization) { + auto status = QuantizeModelAllOperators( + &model_, TensorType_INT8, TensorType_INT8, /*allow_float=*/false, + TensorType_INT8, &error_reporter_, output_buffer_, + /*disable_per_channel_for_dense_layers=*/ + disable_per_channel_quantization_for_dense_); + ASSERT_THAT(status, Eq(kTfLiteOk)); + const auto& subgraph = model_.subgraphs[0]; + auto fc_op = subgraph->operators[0].get(); + + ASSERT_THAT(fc_op->inputs, SizeIs(3)); + ASSERT_THAT(fc_op->outputs, SizeIs(1)); + + const int input_tensor_idx = 0; + const int weights_tensor_idx = 1; + const int bias_tensor_index = 2; + const int output_tensor_idx = 0; + const auto bias_tensor = + 
subgraph->tensors[fc_op->inputs[bias_tensor_index]].get(); + const auto input_tensor = + subgraph->tensors[fc_op->inputs[input_tensor_idx]].get(); + const auto weights_tensor = + subgraph->tensors[fc_op->inputs[weights_tensor_idx]].get(); + const auto output_tensor = + subgraph->tensors[fc_op->outputs[output_tensor_idx]].get(); + + EXPECT_THAT(bias_tensor->type, Eq(TensorType_INT32)); + EXPECT_THAT(input_tensor->type, Eq(TensorType_INT8)); + EXPECT_THAT(weights_tensor->type, Eq(TensorType_INT8)); + EXPECT_THAT(output_tensor->type, Eq(TensorType_INT8)); + + ASSERT_TRUE(weights_tensor->quantization); + ASSERT_TRUE(bias_tensor->quantization); + ASSERT_TRUE(weights_tensor->quantization); + const std::vector& bias_scales = bias_tensor->quantization->scale; + const std::vector& weights_scales = + weights_tensor->quantization->scale; + const std::vector& weights_zero_points = + weights_tensor->quantization->zero_point; + + const int out_channel_size = 2; + ASSERT_THAT(bias_scales, SizeIs(disable_per_channel_quantization_for_dense_ + ? 1 + : out_channel_size)); + ASSERT_THAT(weights_scales, SizeIs(disable_per_channel_quantization_for_dense_ + ? 1 + : out_channel_size)); + ASSERT_THAT( + weights_zero_points, + SizeIs(disable_per_channel_quantization_for_dense_ ? 1 + : out_channel_size)); + ASSERT_THAT(input_tensor->quantization->scale, SizeIs(1)); + ASSERT_THAT(output_tensor->quantization->scale, SizeIs(1)); + + const float eps = 1e-7; + + // Bias scale should be input * per_channel_weight_scale. + for (size_t i = 0; i < out_channel_size; i++) { + EXPECT_THAT((disable_per_channel_quantization_for_dense_ ? bias_scales[0] + : bias_scales[i]), + FloatNear(input_tensor->quantization->scale[0] * + (disable_per_channel_quantization_for_dense_ + ? 
weights_scales[0] + : weights_scales[i]), + eps)); + } + + const auto bias_buffer = model_.buffers[bias_tensor->buffer].get(); + auto control_size = sizeof(int32_t) * bias_tensor->shape[0]; + + ASSERT_THAT(bias_buffer->data, SizeIs(control_size)); + const auto float_op = + readonly_model_->subgraphs()->Get(0)->operators()->Get(0); + const auto original_bias_tensor = + readonly_model_->subgraphs()->Get(0)->tensors()->Get( + float_op->inputs()->Get(2)); + ASSERT_THAT(bias_buffer->data, SizeIs(control_size)); + const auto original_bias_buffer = + readonly_model_->buffers()->Get(original_bias_tensor->buffer()); + const float* bias_float_buffer = + reinterpret_cast(original_bias_buffer->data()->data()); + + int32_t* bias_values = reinterpret_cast(bias_buffer->data.data()); + for (size_t i = 0; i < out_channel_size; i++) { + const float bias_scale = disable_per_channel_quantization_for_dense_ + ? bias_scales[0] + : bias_scales[i]; + auto dequantized_value = bias_values[i] * bias_scale; + EXPECT_THAT(dequantized_value, + FloatNear(bias_float_buffer[i], bias_scale / 2)); + } + + const auto weights_buffer = model_.buffers[weights_tensor->buffer].get(); + const auto original_weights_tensor = + readonly_model_->subgraphs()->Get(0)->tensors()->Get( + float_op->inputs()->Get(1)); + const auto original_weights_buffer = + readonly_model_->buffers()->Get(original_weights_tensor->buffer()); + const int8_t* weight_values = + reinterpret_cast(weights_buffer->data.data()); + const float* weights_float_buffer = + reinterpret_cast(original_weights_buffer->data()->data()); + ASSERT_THAT(sizeof(float) * weights_buffer->data.size(), + Eq(original_weights_buffer->data()->size())); + int num_values_in_channel = weights_buffer->data.size() / out_channel_size; + for (size_t channel_idx = 0; channel_idx < out_channel_size; channel_idx++) { + for (size_t j = 0; j < num_values_in_channel; j++) { + size_t element_idx = channel_idx * num_values_in_channel + j; + auto scale = 
disable_per_channel_quantization_for_dense_ + ? weights_scales[0] + : weights_scales[channel_idx]; + auto zero_point = disable_per_channel_quantization_for_dense_ + ? weights_zero_points[0] + : weights_zero_points[channel_idx]; + auto dequantized_value = weight_values[element_idx] * scale; + EXPECT_THAT(dequantized_value, + FloatNear(weights_float_buffer[element_idx], scale / 2)); + EXPECT_THAT(zero_point, Eq(0)); + } + } + + // check op and versioning. + EXPECT_THAT(model_.operator_codes, SizeIs(1)); + EXPECT_THAT(GetBuiltinCode(model_.operator_codes[0].get()), + Eq(BuiltinOperator_FULLY_CONNECTED)); + ASSERT_THAT(model_.operator_codes[0]->version, 5); +} + +INSTANTIATE_TEST_SUITE_P(QuantizeFCTestInst, QuantizeFCTest, testing::Bool()); + class QuantizeCustomOpTest : public QuantizeModelTest, public ::testing::WithParamInterface { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index fcf36de6247f02..66175aabf394e4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -45,6 +45,7 @@ struct CustomOpInfo { using ::tflite::optimize::ReducedPrecisionSupport; using CustomOpMap = std::unordered_map; enum CustomOpUpdateOptions { kInputIndices, kWeightOnly, kNoSideEffect }; +enum class QDQConversionMode { kQDQNone, kQDQStatic, kQDQDynamic }; struct QuantizationSpecs { // Which function this node quant specifications belong to. @@ -85,6 +86,11 @@ struct QuantizationSpecs { // weight FakeQuant). bool disable_per_channel = false; + // Disables per channel weights quantization for Dense layers and enables + // legacy per tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1 which always performs per channel quantization. 
+ bool disable_per_channel_for_dense_layers = false; + // When set to true, the fixed output ranges of the activation ops (tanh, // sigmoid, etc.) and the weight constants are not inferred. Then, to quantize // these ops, quantization emulation ops should be placed after the ops in the @@ -209,9 +215,11 @@ struct QuantizationSpecs { // For dynamic range quantization, among the custom ops in the graph those // specified in this map are subject to quantization. CustomOpMap custom_map; -}; -enum class QDQConversionMode { kQDQNone, kQDQStatic, kQDQDynamic }; + // If other than kQDQNone, the model is a floating point graph with QDQ ops + // to be eliminated and fused into quantized kernels. + QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; +}; // Parses the command line flag strings to the CustomOpMap specification. void ParseCustomOpSpecs(absl::string_view node_names, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 62c2733d2b510c..408540cd84a146 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -32,6 +33,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project @@ -109,7 +111,8 @@ class QuantizationDriver { bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, - bool infer_tensor_range, bool legacy_float_scale) + bool infer_tensor_range, bool legacy_float_scale, + bool is_qdq_conversion) : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed), @@ -118,7 +121,8 @@ class QuantizationDriver { op_quant_spec_getter_(op_quant_spec_getter), op_quant_scale_spec_getter_(op_quant_scale_spec_getter), infer_tensor_range_(infer_tensor_range), - legacy_float_scale_(legacy_float_scale) {} + legacy_float_scale_(legacy_float_scale), + is_qdq_conversion_(is_qdq_conversion) {} // The entry point of the quantization parameters propagation. void Run(); @@ -198,7 +202,7 @@ class QuantizationDriver { // Returns the quantization params for the bias input from the non-bias // operands which have their indexes in the `non_biases` vector. The returned // parameters are calculated by `func`. - QuantParams GetBiasParams(Operation *op, int bias, + QuantParams GetBiasParams(Operation *op, int bias_index, const std::vector &non_biases, AccumulatorScaleFunc func); @@ -429,6 +433,10 @@ class QuantizationDriver { // Calculate scales in float instead of double, so that the scales and // quantized values are exactly the same with the TOCO quantizer. bool legacy_float_scale_; + + // If true, the model is a floating point graph with QDQ ops to be eliminated + // and fused into quantized kernels. 
+ bool is_qdq_conversion_; }; } // namespace @@ -518,20 +526,35 @@ bool QuantizationDriver::SetResultParams(Operation *op, int res_index, } QuantParams QuantizationDriver::GetBiasParams( - Operation *op, int bias, const std::vector &non_biases, + Operation *op, const int bias_index, const std::vector &non_biases, AccumulatorScaleFunc func) { - auto &bias_state = GetOperandQuantState(op, bias); + QuantState &bias_state = GetOperandQuantState(op, bias_index); if (!bias_state.IsEmpty()) { return bias_state.params; } std::vector op_types; op_types.reserve(non_biases.size()); + int adjusted_quant_dim = -1; + if (op->getNumOperands() > bias_index) { + // Some kernels allow 1D bias, broadcasting it inside the kernel. In this + // case, the `quantizedDimension=0` when quantizing per-channel. + // However, for some kernels which require bias to be already broadcasted + // to match the accumulation shape, the very last index should be used. + Operation *bias_op = op->getOperand(bias_index).getDefiningOp(); + if (bias_op != nullptr) { + Type bias_type = bias_op->getResult(0).getType(); + if (bias_type != builder_.getNoneType()) { + int bias_rank = bias_type.dyn_cast().getRank(); + adjusted_quant_dim = bias_rank > 1 ? bias_rank - 1 : 0; + } + } + } + for (auto non_bias : non_biases) { auto &non_bias_type = GetOperandQuantState(op, non_bias); op_types.push_back(non_bias_type.params); } - if (op_types.empty()) return {}; - return func(op_types, legacy_float_scale_); + return func(op_types, adjusted_quant_dim, legacy_float_scale_); } bool QuantizationDriver::SetOperandParams(Operation *op, int index, @@ -956,7 +979,10 @@ bool QuantizationDriver::PropagateParams() { } } - if (scale_spec->has_fixed_output_range && infer_tensor_range_) { + // If the model already contains immutable QDQs, require upstream to + // explicitly fix output range instead. + if (scale_spec->has_fixed_output_range && infer_tensor_range_ && + !is_qdq_conversion_) { // Infer ranges from the activation ops. 
This is usually required for // the post-training quantization workflow. // TODO(fengliuai): different result can have different fixed range. @@ -1182,20 +1208,22 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale) { + bool legacy_float_scale, + bool is_qdq_conversion) { ApplyQuantizationParamsPropagation( func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, - GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale); + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + is_qdq_conversion); } void ApplyQuantizationParamsPropagation( mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale) { + bool legacy_float_scale, bool is_qdq_conversion) { QuantizationDriver(func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, op_quant_scale_spec_getter, - infer_tensor_ranges, legacy_float_scale) + infer_tensor_ranges, legacy_float_scale, is_qdq_conversion) .Run(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 9a151a80e8f48b..53f8024c7900dd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,13 +30,20 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" @@ -43,7 +51,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" namespace mlir { @@ -469,7 +477,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim, quant::QuantizedType GetUniformQuantizedTypeForBias( const std::vector& op_types, - bool legacy_float_scale) { + const int adjusted_quant_dim, const bool legacy_float_scale) { if (op_types.empty()) return {}; size_t axis_size = 1; @@ -531,13 +539,14 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( /*zeroPoint=*/0, storage_type_min, storage_type_max); } else { llvm::SmallVector zero_points(axis_size, 0); - // Assume the bias is a 1-D tensor, and set the quantization dim to the last - // dimension, which is 0. If the bias rank is larger than 1, this returned - // quantized type couldn't be used to quantize the bias. + // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. + // If the bias rank is larger than 1 because it was already broadcasted + // to match the output shape, use the last index. return quant::UniformQuantizedPerAxisType::getChecked( builder.getUnknownLoc(), /*flags=*/true, storage_type, expressed_type, scales, zero_points, - /*quantizedDimension=*/0, storage_type_min, storage_type_max); + /*quantizedDimension=*/std::max(adjusted_quant_dim, 0), + storage_type_min, storage_type_max); } } @@ -598,7 +607,7 @@ ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) { return DenseElementsAttr::get(new_dense_type, quantized_attr); } else if (width == 8) { // This can be a state tensor, or an actual constant tensor with - // asymmetric range. 
For a state tensor, assigining correct quantization + // asymmetric range. For a state tensor, assigning correct quantization // parameters is sufficient, and for constants with asymmetric range it's // not correctly quantized by legacy quantizer so call the new Quantize. return Quantize(real_value, tensor_type); @@ -643,7 +652,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) { quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; return quantfork::quantizeAttr(real_value, q_type, converted_type) - .dyn_cast(); + .dyn_cast_or_null(); } return {}; } @@ -816,7 +825,7 @@ bool RemoveRedundantStatsOps( } } - // Step 2: backward pass: For the ops skiped in the forward pass, propagate + // Step 2: backward pass: For the ops skipped in the forward pass, propagate // its results scale backwards as far as possible. func.walk([&](quantfork::StatisticsOp stats_op) { if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 1113bb868fa3e8..e1b697e3be67d1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -66,29 +66,29 @@ namespace quant { // A unit attribute can be attached to the quantize/dequantize ops which are // added by the quantization passes. These ops can be removed erased without // losing accuracy. -constexpr char kVolatileOpAttrName[] = "volatile"; +inline constexpr char kVolatileOpAttrName[] = "volatile"; // Following attributes are used to mark ops that are not quantizable during // debug model generation process for whole-model verify mode. If these // attributes are attached, the upstream float/quantized ops know which ops to // connect to, and it also prevents these ops from being copied again. 
-constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; -constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; // Used to annotate custom ops if they are quantizable. -constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; -constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", - "not_quantizable"}; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; -constexpr double kNearZeroTolerance = 1.0e-6; +inline constexpr double kNearZeroTolerance = 1.0e-6; using QuantParams = QuantizedType; using QuantSpec = QuantizationSpecs; using SignedInteger = std::pair; // bitwidth and sign using QuantParamsForResults = llvm::SmallVector; using AccumulatorScaleFunc = - std::function&, bool)>; + std::function&, int, bool)>; using BiasParamsMap = std::unordered_map, AccumulatorScaleFunc>>; // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) @@ -890,7 +890,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( // other operands which are multiply-accumulated (the bias is added to the // accumulated value). 
quant::QuantizedType GetUniformQuantizedTypeForBias( - const std::vector& op_types, + const std::vector& op_types, int adjusted_quant_dim, bool legacy_float_scale = false); // Propagates quantization parameters across ops in this function and satisfy @@ -906,13 +906,14 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale = false); + bool legacy_float_scale = false, + bool is_qdq_conversion = false); void ApplyQuantizationParamsPropagation( mlir::func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale = false); + bool legacy_float_scale = false, bool is_qdq_conversion = false); // Gets quantization scale specs (e.g. fixed output range, same result and // operand scales) from the default quantization interfaces. 
The op should diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD new file mode 100644 index 00000000000000..7f6b74431a95b7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -0,0 +1,48 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/lite:__subpackages__"], + licenses = ["notice"], +) + +cc_library( + name = "quantization", + srcs = ["quantization.cc"], + hdrs = ["quantization.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "quantization_test", + srcs = ["quantization_test.cc"], + deps = [ + ":quantization", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference + 
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", # buildcleaner: keep; prevents undefined reference + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc new file mode 100644 index 00000000000000..929634164fba3b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace { + +using ::mlir::quant::stablehlo::StaticRangePtqComponent; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::PyFunctionLibrary; + +// Returns signature key -> `SignatureDef` mapping, excluding the signature for +// initialization op, which is only used during initialization. +// TODO: b/314124142 - Remove the need for this function. +absl::flat_hash_map GetSignatureDefMapFromBundle( + const SavedModelBundle& saved_model_bundle) { + // Translate protobuf::Map -> absl::flat_hash_map. + const protobuf::Map& signatures = + saved_model_bundle.GetSignatures(); + absl::flat_hash_map signature_def_map( + signatures.begin(), signatures.end()); + + // Init op is only used during initialization and it's not a target for + // quantization. + signature_def_map.erase(kSavedModelInitOpSignatureKey); + return signature_def_map; +} + +// Retrieves the function name -> function alias mapping from the +// `SavedModelBundle`. 
+// TODO: b/314124142 - Remove the need for this function. +absl::flat_hash_map GetFunctionAliases( + const SavedModelBundle& saved_model_bundle) { + const protobuf::Map& function_aliases = + saved_model_bundle.meta_graph_def.meta_info_def().function_aliases(); + return absl::flat_hash_map(function_aliases.begin(), + function_aliases.end()); +} + +} // namespace + +absl::StatusOr RunQuantization( + const SavedModelBundle* saved_model_bundle, + const absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const QuantizationConfig& quantization_config, + const PyFunctionLibrary* quantization_py_function_lib, + mlir::ModuleOp module_op) { + if (saved_model_bundle == nullptr) { + return absl::InvalidArgumentError( + "Failed to run quantization. `saved_model_bundle` should not be " + "nullptr."); + } + + if (quantization_py_function_lib == nullptr) { + return absl::InvalidArgumentError( + "Failed to run quantization. `quantization_py_function_lib` should not " + "be nullptr."); + } + + const absl::flat_hash_map signature_def_map = + GetSignatureDefMapFromBundle(*saved_model_bundle); + + std::vector exported_names; + for (const auto& [key, value_unused] : signature_def_map) { + exported_names.push_back(key); + } + + if (failed(mlir::tf_saved_model::FreezeVariables( + module_op, saved_model_bundle->GetSession()))) { + return absl::InternalError("Failed to freeze variables."); + } + + StaticRangePtqComponent static_range_ptq_component( + module_op.getContext(), quantization_py_function_lib, saved_model_dir, + /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, + GetFunctionAliases(*saved_model_bundle)); + const absl::StatusOr quantized_module_op = + static_range_ptq_component.Run(module_op, quantization_config); + if (!quantized_module_op.ok()) { + return absl::InternalError("Failed to run quantization. 
Status msg: " + + quantized_module_op.status().ToString()); + } + return quantized_module_op; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h new file mode 100644 index 00000000000000..c55d59cad0f1a0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Adaptor functions for StableHLO Quantizer. +// Provides simpler interfaces when integrating StableHLO Quantizer into TFLite +// Converter. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + +namespace tensorflow { + +// Runs quantization on `module_op`. `saved_model_bundle` is required to +// retrieve information about the original model (e.g. 
signature def mapping) +// because quantization requires exporting the intermediate `ModuleOp` back to +// SavedModel for calibration. Similarly, `saved_model_dir` is required to +// access the assets of the original model. `saved_model_tags` uniquely +// identifies the `MetaGraphDef`. `quantization_config` determines the behavior +// of StableHLO Quantizer. `quantization_py_function_lib` contains python +// implementations of certain APIs that are required for calibration. +// `module_op` is the input graph to be quantized and it should contain +// StableHLO ops. +// +// Returns a quantized `ModuleOp` in StableHLO, potentially wrapped inside a +// XlaCallModuleOp. Returns a non-OK status if quantization fails, or any of +// `saved_model_bundle` or `quantization_py_function_lib` is a nullptr. +absl::StatusOr RunQuantization( + const SavedModelBundle* saved_model_bundle, + absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const stablehlo::quantization::QuantizationConfig& quantization_config, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_lib, + mlir::ModuleOp module_op); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc new file mode 100644 index 00000000000000..3cbc9e6ea47864 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Test cases for the StableHLO Quantizer adaptor functions. + +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tsl/platform/status_matchers.h" + +namespace tensorflow { +namespace { + +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::CreateTmpDir; +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +// Test cases for `RunQuantization` mainly tests for error cases because testing +// for successful cases require passing python implementation to +// `quantization_py_function_lib`, which requires testing from the python level. +// Internal integration tests exist for testing successful quantization. 
+ +TEST(RunQuantizationTest, + WhenSavedModelBundleIsNullptrReturnsInvalidArgumentError) { + const absl::StatusOr tmp_saved_model_dir = CreateTmpDir(); + ASSERT_THAT(tmp_saved_model_dir, IsOk()); + + const absl::StatusOr quantized_module_op = RunQuantization( + /*saved_model_bundle=*/nullptr, *tmp_saved_model_dir, + /*saved_model_tags=*/{}, QuantizationConfig(), + /*quantization_py_function_lib=*/nullptr, /*module_op=*/{}); + EXPECT_THAT( + quantized_module_op, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("`saved_model_bundle` should not be nullptr"))); +} + +TEST(RunQuantizationTest, + WhenPyFunctionLibIsNullptrReturnsInvalidArgumentError) { + const absl::StatusOr tmp_saved_model_dir = CreateTmpDir(); + ASSERT_THAT(tmp_saved_model_dir, IsOk()); + + // Dummy SavedModelBundle to pass a non-nullptr argument. + SavedModelBundle bundle{}; + const absl::StatusOr quantized_module_op = RunQuantization( + /*saved_model_bundle=*/&bundle, *tmp_saved_model_dir, + /*saved_model_tags=*/{}, QuantizationConfig(), + /*quantization_py_function_lib=*/nullptr, /*module_op=*/{}); + EXPECT_THAT( + quantized_module_op, + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("`quantization_py_function_lib` should not be nullptr"))); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index ad4112a05ad4a9..47440b4c4c0beb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -15,7 +15,10 @@ limitations under the License. 
#include +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/PrettyStackTrace.h" @@ -24,6 +27,7 @@ limitations under the License. #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include "mlir/TableGen/Operator.h" // from @llvm-project +#include "mlir/TableGen/Trait.h" // from @llvm-project using llvm::LessRecord; using llvm::raw_ostream; @@ -50,7 +54,8 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { llvm::sort(defs, LessRecord()); OUT(0) << "static std::unique_ptr " - "GetOpQuantSpec(mlir::Operation *op) {\n"; + "GetOpQuantSpec(mlir::Operation *op, bool " + "disable_per_channel_for_dense_layers = false) {\n"; // TODO(b/176258587): Move to OpTrait if this should be generalized. // Add special handling for LSTM. OUT(2) << "if (auto lstm_op = llvm::dyn_cast(op)) {\n"; @@ -94,7 +99,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { // There is a "QuantChannelDim" trait, set the quantization dimension. if (coeff_index_trait_regex.match(trait_str, &matches)) { OUT(4) << "spec->coeff_op_quant_dim[tfl.GetCoefficientOperandIndex()" - << "] = tfl.GetQuantizationDim();\n"; + << "] = llvm::dyn_cast(op) && " + "disable_per_channel_for_dense_layers ? 
-1 : " + "tfl.GetQuantizationDim();\n"; matches.clear(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index feb47bfe4420b0..89d6139ba7ba19 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -280,7 +280,7 @@ cc_library( deps = [ ":passes_inc_gen", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", - "//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -407,12 +407,14 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -455,10 +457,12 @@ cc_library( deps = [ ":passes_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -599,8 +603,10 @@ tf_cc_binary( "//tensorflow/compiler/mlir/lite:flatbuffer_export", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", + 
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir index 795840247cab93..292802ec92e5e1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/call_xla_module_to_stablehlo.mlir @@ -17,7 +17,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr } } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { +// CHECK: module attributes +// CHECK-SAME: tfl.metadata = {{{.*}}keep_stablehlo_constant = "true"{{.*}}} // CHECK-NEXT: func.func @main(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> attributes {tf.entry_function = {inputs = "args_tf_0", outputs = "Identity"}} { // CHECK-NEXT: %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK-NEXT: %1 = stablehlo.multiply %0, %0 : tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir index f50c399deb5579..d0da1f09fa5ae1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir @@ -8,10 +8,8 @@ 
func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { } } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { +// CHECK: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { // CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %1 : tensor<2xi32> // CHECK-NEXT: } -// CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir index 3b5dae4706e877..85653de898aa01 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir @@ -2,7 +2,7 @@ module { func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0= "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0= "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> func.return %0 : tensor<1x2x2xi32> } } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir index 4030506472268c..aa7742c15e4c42 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir @@ -13,15 +13,15 @@ func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> output_batch_dimension = 3, output_feature_dimension = 0, output_spatial_dimensions = [1, 2] - >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : + >, feature_group_count = 1 : i64, lhs_dilation = array, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = array, window_strides = array, window_reversal = array} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> func.return %0 : tensor<16x8x8x1xf32> } } -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> -// CHECK-NEXT: return %0 : tensor<16x8x8x1xf32> -// CHECK-NEXT: } +// CHECK: module { +// CHECK-NEXT: func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> +// CHECK-NEXT: return %0 : tensor<16x8x8x1xf32> +// CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir index 5fb78f0540c3c5..47c716c0ca5243 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir @@ -9,7 +9,7 @@ func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tens start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, - slice_sizes = dense<[1, 1, 256]> : tensor<3xi64>} : + slice_sizes = array} : (tensor<1x128x256xf32>, tensor<30x1x2xi32>) -> tensor<30x1x256xf32> func.return %0 : tensor<30x1x256xf32> } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir index 65f988cebe7c70..a8a3c8a18683b5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir @@ -2,13 +2,14 @@ module { func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { - %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> + %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> func.return %0 : tensor<16x8x8x1xf32> } } - -// CHECK: module { -// CHECK: func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { -// CHECK: %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [0, 1, b, f]x[0, 1, o, i]->[f, 0, 1, b], window = {stride = [1, 1], pad = {{\[}}[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> -// CHECK: return %0 : tensor<16x8x8x1xf32> +// CHECK: module { +// CHECK-NEXT: func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: 
tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { +// CHECK-NEXT: %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [0, 1, b, f]x[0, 1, o, i]->[f, 0, 1, b], window = {stride = [1, 1], pad = {{\[}}[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [true, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> +// CHECK-NEXT: return %0 : tensor<16x8x8x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index d9d86dac6782e6..4152a1b785daa6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -8,7 +8,7 @@ // CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> // CHECK: } func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } @@ -19,7 +19,7 @@ func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x32x10x32xi32> // CHECK: } func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : 
tensor<1x32x10x32xi32> } @@ -30,7 +30,7 @@ func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> // CHECK: return %[[VAL_2]] : tensor // CHECK: } func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %0 : tensor } @@ -68,7 +68,7 @@ func.func @broadcast_add(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_add_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -79,14 +79,14 @@ func.func @broadcast_add_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t // CHECK: return %[[VAL_2]] : tensor<4x4x4x4xi32> // CHECK: } func.func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> func.return %0 : tensor<4x4x4x4xi32> } // CHECK-LABEL: func @unsupported_broadcast_add // CHECK: chlo.broadcast_add func.func @unsupported_broadcast_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x1x1xi32>, 
tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> func.return %0 : tensor<4x4x4x4xi32> } @@ -122,7 +122,7 @@ func.func @broadcast_div(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_div_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -159,7 +159,7 @@ func.func @broadcast_shift_left(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> ( // CHECK: return %[[VAL_2]] : tensor // CHECK: } func.func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %0 : tensor } @@ -247,7 +247,7 @@ func.func @broadcast_mul(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_mul_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -268,7 +268,7 @@ func.func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func 
@broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -304,7 +304,7 @@ func.func @broadcast_sub(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> ( // CHECK: return %[[VAL_2]] : tensor<1x2xi32> // CHECK: } func.func @broadcast_sub_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -341,7 +341,7 @@ func.func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<2x4xi32> // CHECK: } func.func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> func.return %0 : tensor<2x4xi32> } @@ -363,7 +363,7 @@ func.func @and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_and"(%arg0, 
%arg1) {broadcast_dimensions = array} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -396,7 +396,7 @@ func.func @or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -441,7 +441,7 @@ func.func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> t // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_or_broadcast_chlo(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> func.return %0 : tensor<1x4xi8> } @@ -474,7 +474,7 @@ func.func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_xor_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>) -> tensor<1x4xi8> %1 = mhlo.xor %0, %arg1 : tensor<1x4xi8> func.return %1 : tensor<1x4xi8> } @@ -509,7 +509,7 @@ func.func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> // CHECK: return %[[VAL_2]] : tensor<1x4xi8> // CHECK: } func.func @bitwise_and_broadcast_chlo(%arg0: 
tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> func.return %0 : tensor<1x4xi8> } @@ -584,16 +584,16 @@ func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = #chlo} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = mhlo.constant dense<0> : tensor<3xi32> %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = mhlo.constant dense<1> : tensor<3xi32> %9 = mhlo.subtract %7, %8 : tensor<3xi32> - %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "chlo.broadcast_divide"(%11, %12) 
{broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = array} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> func.return %14 : tensor<2x3xi32> } @@ -623,13 +623,13 @@ func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = mhlo.constant dense<0> : tensor<2x3xi32> %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = mhlo.constant dense<1> : tensor<2x3xi32> %9 = mhlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = mhlo.divide %11, %12 : 
tensor<2x3xi32> @@ -660,8 +660,8 @@ func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: return %[[VAL_4]] : tensor<2x3xf16> // CHECK: } func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> func.return %2 : tensor<2x3xf16> } @@ -707,7 +707,7 @@ func.func @equal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> te // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @equal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -718,7 +718,7 @@ func.func @equal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> 
tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -736,7 +736,7 @@ func.func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: t // CHECK-LABEL: func @equal_unsupported_compare_type func.func @equal_unsupported_compare_type(%arg0: tensor<1xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xi1> { // CHECK: chlo.broadcast_compare - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, compare_type = #chlo, comparison_direction = #chlo} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, compare_type = #chlo, comparison_direction = #chlo} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -770,7 +770,7 @@ func.func @notequal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @notequal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -781,7 +781,7 @@ func.func @notequal_broadcast_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> 
tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -826,7 +826,7 @@ func.func @broadcast_greater(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -867,7 +867,7 @@ func.func @broadcast_greater_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32 // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_greater_equal_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -901,7 +901,7 @@ func.func @broadcast_less(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> ten // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, 
comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -935,7 +935,7 @@ func.func @broadcast_less_equal(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) // CHECK: return %[[VAL_2]] : tensor<1x2xi1> // CHECK: } func.func @broadcast_less_equal_chlo(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> func.return %0 : tensor<1x2xi1> } @@ -980,7 +980,7 @@ func.func @const() -> tensor<2xi32> { // CHECK: } func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor - %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> func.return %1 : tensor<1xi32> } @@ -992,7 +992,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: } func.func @relu_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor - %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %1 : tensor } @@ -1007,8 +1007,8 @@ func.func @relu_unranked(%arg0: tensor) -> tensor { func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = 
"chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array} : (tensor<1xi32>, tensor) -> tensor<1xi32> func.return %3 : tensor<1xi32> } @@ -1023,8 +1023,8 @@ func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { func.func @relu6_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array} : (tensor, tensor) -> tensor + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array} : (tensor, tensor) -> tensor func.return %3 : tensor } @@ -1039,7 +1039,7 @@ func.func @relu6_unranked(%arg0: tensor) -> tensor { // CHECK: } func.func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = #chlo} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor, tensor) -> tensor %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "mhlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> func.return %3 : tensor<4x8xf32> @@ -2140,6 +2140,8 @@ func.func @convert_dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, func.return %0 : tensor<4x4x256xf32> } + + // CHECK-LABEL: func.func @convert_conv1d( // CHECK-SAME: %[[VAL_0:.*]]: 
tensor<16x32x256xbf16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { @@ -2204,7 +2206,29 @@ func.func @convert_conv1d_dynamic_batch(%arg0: tensor, %arg1: ten func.return %0 : tensor } - +// CHECK-LABEL: convert_dynamic_1d_group_conv +func.func private @convert_dynamic_1d_group_conv(%arg1: tensor, %arg2: tensor<768x48x128xf32>) -> (tensor) { + %0 = mhlo.convolution(%arg1, %arg2) + dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], + window = {pad = [[64, 64]]} + {batch_group_count = 1 : i64, feature_group_count = 16 : i64} + : (tensor, tensor<768x48x128xf32>) -> tensor + return %0 : tensor +// CHECK: %cst = arith.constant dense<[-9223372036854775808, 768, 2, 1]> : tensor<4xi64> +// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_0 = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tf.Transpose"(%0, %cst_0) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_1 = arith.constant dense<[768, 48, 128, 1]> : tensor<4xi64> +// CHECK: %2 = "tf.Reshape"(%arg1, %cst_1) : (tensor<768x48x128xf32>, tensor<4xi64>) -> tensor<768x48x128x1xf32> +// CHECK: %cst_2 = "tf.Const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %3 = "tf.Transpose"(%2, %cst_2) : (tensor<768x48x128x1xf32>, tensor<4xi64>) -> tensor<128x1x48x768xf32> +// CHECK: %4 = "tf.Conv2D"(%1, %3) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 64, 64, 0, 0, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor, tensor<128x1x48x768xf32>) -> tensor +// CHECK: %cst_3 = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %5 = "tf.Transpose"(%4, %cst_3) : (tensor, tensor<4xi64>) -> tensor +// CHECK: %cst_4 = arith.constant dense<[-9223372036854775808, 768, 3]> : tensor<3xi64> +// CHECK: %6 = "tf.Reshape"(%5, %cst_4) : (tensor, tensor<3xi64>) -> tensor +// CHECK: 
return %6 : tensor +} // CHECK-LABEL: func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf( // CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, @@ -2318,13 +2342,22 @@ func.func @no_convert_conv1d_dynamic(%arg0: tensor<16x?x256xbf16>, %arg1: tensor func.return %0 : tensor<16x?x256xbf16> } -// CHECK-LABEL: func.func @no_convert_conv1d_feature_group_gt_1( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {stride = [1], pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<16x32x256xbf16>, tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> -// CHECK: return %[[VAL_2]] : tensor<16x32x128xbf16> -// CHECK: } -func.func @no_convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { +// CHECK-LABEL: func.func @convert_conv1d_feature_group_gt_1( +// CHECK: %cst = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %cst_0 = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tf.Transpose"(%0, %cst_0) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK: %cst_1 = arith.constant dense<[1, 128, 128, 1]> : tensor<4xi64> +// CHECK: %2 = "tf.Reshape"(%arg1, %cst_1) : (tensor<1x128x128xbf16>, tensor<4xi64>) -> tensor<1x128x128x1xbf16> +// CHECK: %cst_2 = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %3 = "tf.Transpose"(%2, %cst_2) : (tensor<1x128x128x1xbf16>, tensor<4xi64>) -> tensor<1x1x128x128xbf16> +// CHECK: %4 = "tf.Conv2D"(%1, %3) 
<{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x128x128xbf16>) -> tensor<16x32x1x128xbf16> +// CHECK: %cst_3 = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %5 = "tf.Transpose"(%4, %cst_3) : (tensor<16x32x1x128xbf16>, tensor<4xi64>) -> tensor<16x32x128x1xbf16> +// CHECK: %cst_4 = arith.constant dense<[16, 32, 128]> : tensor<3xi64> +// CHECK: %6 = "tf.Reshape"(%5, %cst_4) : (tensor<16x32x128x1xbf16>, tensor<3xi64>) -> tensor<16x32x128xbf16> +// CHECK: return %6 : tensor<16x32x128xbf16> +func.func @convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir index 4e12ffd931c5f2..27e22cb524b8af 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir @@ -11,12 +11,12 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { } } -//CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { +//CHECK: module attributes +//CHECK-SAME: keep_stablehlo_constant = "true" //CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "stablehlo.dynamic_update_slice"}} { //CHECK-DAG: %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> //CHECK-DAG: %1 = stablehlo.constant dense<1> : tensor 
//CHECK-DAG: %2 = stablehlo.constant dense<0> : tensor //CHECK-NEXT: %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> //CHECK-NEXT: return %3 : tensor<2x1x2xf32> -//CHECK-NEXT: } -//CHECK-NEXT:} \ No newline at end of file +//CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index c87e75e8ed1ee5..fd850862f15196 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -1,12 +1,19 @@ // RUN: odml-to-stablehlo-opt --uniform-quantized-stablehlo-to-tfl \ // RUN: --split-input-file --verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: uniform_quantize_op +// ============================================================================ +// The following functions tests example quantization patterns outputted from +// JAX Quantizer. JAX Quantizer should output integer types, which are +// composed into `UniformQuantized{|PerAxis}Type` via +// `compose_uniform_quantized_type_pass.cc`. 
+// ============================================================================ + func.func @uniform_quantize_op(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } -// CHECK: %[[QUANT:.*]] = "tfl.quantize"({{.*}}) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK-LABEL: uniform_quantize_op +// CHECK: %[[QUANT:.+]] = "tfl.quantize"({{.*}}) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> // CHECK: return %[[QUANT]] // ----- @@ -14,11 +21,11 @@ func.func @uniform_quantize_op(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.unifo // Tests that the pattern doesn't match when the input tensor's type is a // quantized type. -// CHECK-LABEL: uniform_quantize_op_quantized_input func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_quantized_input // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize @@ -28,11 +35,11 @@ func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_uint16_output // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize @@ -42,22 +49,22 @@ func.func @uniform_quantize_op_uint16_output(%arg: tensor<2x2xf32>) -> tensor<2x // is i32. i32 storage type for quantized type is not compatible with // `tfl.quantize`. 
-// CHECK-LABEL: uniform_quantize_op_i32_output func.func @uniform_quantize_op_i32_output(%arg: tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } +// CHECK-LABEL: uniform_quantize_op_i32_output // CHECK: stablehlo.uniform_quantize // CHECK-NOT: tfl.quantize // ----- -// CHECK-LABEL: uniform_dequantize_op func.func @uniform_dequantize_op(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"({{.*}}) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> +// CHECK-LABEL: uniform_dequantize_op +// CHECK: %[[DEQUANT:.+]] = "tfl.dequantize"({{.*}}) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> // CHECK: return %[[DEQUANT]] // ----- @@ -66,11 +73,11 @@ func.func @uniform_dequantize_op(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } +// CHECK-LABEL: uniform_dequantize_op_ui16_storage_input // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize @@ -80,11 +87,11 @@ func.func @uniform_dequantize_op_ui16_storage_input(%arg: tensor<2x2x!quant.unif // storage type is i32. i32 storage type is not compatible with // `tfl.dequantize`. -// CHECK-LABEL: uniform_dequantize_op_i32_storage_input func.func @uniform_dequantize_op_i32_storage_input(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } +// CHECK-LABEL: uniform_dequantize_op_i32_storage_input // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize @@ -94,109 +101,104 @@ func.func @uniform_dequantize_op_i32_storage_input(%arg: tensor<2x2x!quant.unifo // storage type is i32. 
i32 storage type is not compatible with // `tfl.dequantize`. -// CHECK-LABEL: uniform_dequantize_op_return_f64 func.func @uniform_dequantize_op_return_f64(%arg: tensor<2x2x!quant.uniform>) -> tensor<2x2xf64> { %0 = stablehlo.uniform_dequantize %arg : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf64> return %0 : tensor<2x2xf64> } +// CHECK-LABEL: uniform_dequantize_op_return_f64 // CHECK: stablehlo.uniform_dequantize // CHECK-NOT: tfl.dequantize // ----- -// CHECK-LABEL: convolution_upstream_full_integer -func.func @convolution_upstream_full_integer(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +func.func @convolution_upstream_same_padding_srq(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %1 : tensor<1x3x3x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[CONST_0:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> // Note that the quantized dimension is 0, and the shape has been transposed // to (2, 3, 3, 4). -// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// Explicit tfl.pad op to reflect explicit padding attribute. 
-// CHECK: %[[PAD:.*]] = "tfl.pad"(%[[ARG]], %[[CONST_0]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_same_padding_srq +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x3x2x!quant.uniform> // ----- -// CHECK-LABEL: convolution_upstream_full_integer_non_const_filter -func.func @convolution_upstream_full_integer_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +func.func @convolution_upstream_srq_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 
1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %0 : tensor<1x3x3x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> - // Confirm that the `stablehlo.convolution` is not converted to `tfl.conv_2d`. +// CHECK-LABEL: convolution_upstream_srq_non_const_filter +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> // CHECK: stablehlo.convolution // CHECK-NOT: tfl.conv_2d // ----- -// Test that if the window padding contains values of 0, tfl.pad op is not +// Tests that if the window padding contains values of 0, tfl.pad op is not // created and the `padding` attribute is set as "VALID". -// CHECK-LABEL: convolution_upstream_full_integer_valid_padding -func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 0], [0, 0]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_valid_padding +// 
CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> // CHECK-NOT: tfl.pad -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- -// Test that if the window padding value is missing, tfl.pad op is not +// Tests that if the window padding value is missing, tfl.pad op is not // created and the `padding` attribute is set as "VALID". -// CHECK-LABEL: convolution_upstream_full_integer_valid_padding -func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +func.func @convolution_upstream_srq_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The `window` attribute is empty. 
%1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_valid_padding +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> // CHECK: return 
%[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> // ----- -// Test that if the window stride value is explicitly set, the attribute +// Tests that if the window stride value is explicitly set, the attribute // value is transferred to tfl.conv_2d's stridw_h and stride_w values. -// CHECK-LABEL: convolution_upstream_full_integer_strides -func.func @convolution_upstream_full_integer_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { +func.func @convolution_upstream_srq_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The stride value is explicitly set to [1, 2]. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 2], pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> return %1 : tensor<1x3x2x2x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> -// CHECK-DAG: %[[CONST:.*]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> -// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> -// CHECK: %[[PAD:.*]] = "tfl.pad"(%arg0, %[[CONST]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> +// CHECK-LABEL: convolution_upstream_srq_strides +// CHECK-SAME: %[[ARG:.+]]: tensor<1x3x3x4x!quant.uniform> +// CHECK-DAG: %[[CONST:.+]] = "tfl.pseudo_const"(){{.*}}dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32> +// 
CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%arg0, %[[CONST]]) : (tensor<1x3x3x4x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x5x5x4x!quant.uniform> // Tests that the stride_w is set to 2. -// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> +// CHECK: %[[CONV2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x5x5x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> // CHECK: return %[[CONV2D]] : tensor<1x3x2x2x!quant.uniform> // ----- -// Test full integer quantized dot_general with asymmetric quantized input. +// Tests static range quantized dot_general with asymmetric quantized input. 
-// CHECK-LABEL: dot_general_upstream_full_integer_asym_input -func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -208,16 +210,16 @@ func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!q } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.*]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK-LABEL: dot_general_upstream_srq_asym_input +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test full integer quantized dot_general with symmetric quantized input. +// Tests static range quantized dot_general with symmetric quantized input. 
-// CHECK-LABEL: dot_general_upstream_full_integer_sym_input -func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -230,41 +232,16 @@ func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!qu } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() +// CHECK-LABEL: dot_general_upstream_srq_sym_input +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} // ----- -// Tests that the pattern does not match when the output tensor's storage -// type is i32. Currently we support qi8, qi8 -> qi8 only for GEMM ops that -// are quantized upstream. Other cases should be handled by regular quantized -// stablehlo.dot_general case. 
+// Tests static range quantized dot_general with activation as RHS -// CHECK-LABEL: dot_general_upstream_full_integer_i32_output -func.func @dot_general_upstream_full_integer_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { - %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> - %1 = "stablehlo.dot_general"(%arg0, %0) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2] - >, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> - return %1 : tensor<1x2x3x5x!quant.uniform> -} -// CHECK: stablehlo.dot_general -// CHECK-NOT: tfl.quantize - -// ----- - -// Test full integer quantized dot_general with activation as RHS - -// CHECK-LABEL: dot_general_upstream_full_integer_activation_rhs -func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -276,14 +253,15 @@ func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %0 : tensor<1x2x3x5x!quant.uniform> } +// CHECK-LABEL: dot_general_upstream_srq_activation_rhs // CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test full integer quantized 
dot_general with adj_x +// Tests static range quantized dot_general with adj_x -// CHECK-LABEL: dot_general_upstream_full_integer_adj_x -func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_adj_x +func.func @dot_general_upstream_srq_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -297,17 +275,15 @@ func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant. } : (tensor<1x2x4x3x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x4x3x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x4x3x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = true, adj_y = false} // ----- -// Test full integer quantized dot_general with adj_y +// Tests static range quantized dot_general with adj_y -// CHECK-LABEL: dot_general_upstream_full_integer_adj_y -func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { 
dot_dimension_numbers = #stablehlo.dot< @@ -321,17 +297,16 @@ func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant. } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x5x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } - -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> +// CHECK-LABEL: dot_general_upstream_srq_adj_y +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x5x4x!quant.uniform>, value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> // CHECK: "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = true} // ----- -// Test full integer quantized dot_general with wrong batch dims +// Tests static range quantized dot_general with wrong batch dims -// CHECK-LABEL: dot_general_upstream_full_integer_too_many_batches -func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x1x1x2x4x5xi8>} : () -> tensor<1x1x1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -345,15 +320,15 @@ func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x return %1 : tensor<1x1x1x2x3x5x!quant.uniform> } // Only support size(batching_dimensions) <= 3 +// CHECK-LABEL: dot_general_upstream_srq_too_many_batches // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with too many contracting dimension +// Tests 
static range quantized dot_general with too many contracting dimension -// CHECK-LABEL: dot_general_upstream_full_integer_too_many_contractions -func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x4x5xi8>} : () -> tensor<1x2x4x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -367,15 +342,15 @@ func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor return %1 : tensor<1x2x3x5x!quant.uniform> } // Only support size(contracting_dimensions) == 1 +// CHECK-LABEL: dot_general_upstream_srq_too_many_contractions // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with unsupported contracting dim +// Tests static range quantized dot_general with unsupported contracting dim -// CHECK-LABEL: dot_general_upstream_full_integer_wrong_contracting -func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { +func.func @dot_general_upstream_srq_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -388,17 +363,17 @@ func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2 } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> return %1 : tensor<1x4x3x5x!quant.uniform> } - // Contracting dimension must be the last two dimension +// CHECK-LABEL: dot_general_upstream_srq_wrong_contracting 
// CHECK: stablehlo.dot_general // CHECK-NOT: tfl.batch_matmul // ----- -// Test full integer quantized dot_general with float operands +// Tests static range quantized dot_general with float operands -// CHECK-LABEL: dot_general_upstream_full_integer_float_operands -func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { +// CHECK-LABEL: dot_general_upstream_srq_float_operands +func.func @dot_general_upstream_srq_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -416,44 +391,44 @@ func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x // ----- -// Test full integer quantized dot_general with asymmetric weight (rhs). +// Tests static range quantized dot_general with asymmetric weight (rhs). -// CHECK-LABEL: dot_general_upstream_full_integer_asym_weight -func.func @dot_general_upstream_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_asym_weight +func.func @dot_general_upstream_srq_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> } -// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x3x4x!quant.uniform> -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> -// CHECK: %[[BMM:.*]] = 
"tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> +// CHECK-SAME: %[[ARG:.+]]: tensor<1x2x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: %[[BMM:.+]] = "tfl.batch_matmul"(%[[ARG]], %[[QCONST_0]]) {adj_x = false, adj_y = false} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized, it is converted to `tfl.fully_connected` op. -// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter +func.func @dot_general_upstream_srq_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> return %1 : tensor<1x2x!quant.uniform> } -// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x!quant.uniform> +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x!quant.uniform> // Weight tensor is transposed, as tfl.fully_connected accepts a [o, i] matrix. 
-// CHECK-DAG: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> -// CHECK-DAG: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform:f32:0, {1.000000e+08,1.500000e+09}>>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform:f32:0, {1.000000e+08,1.500000e+09}>> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<1> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> // Bias tensor's scale is input scale * filter scale. -// CHECK: %[[FC:.*]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform>, tensor<2x!quant.uniform:f32:0, {1.000000e+08,1.500000e+09}>>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[FC:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3x!quant.uniform>, tensor<2x3x!quant.uniform:f32:0, {2.000000e+02,3.000000e+03}>>, tensor<2x!quant.uniform>) -> tensor<1x2x!quant.uniform> // CHECK-NEXT: return %[[FC]] : tensor<1x2x!quant.uniform> // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dimension, it is not converted. 
-// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim +func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> return %1 : tensor<1x1x2x!quant.uniform> @@ -465,11 +440,11 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batc // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dim > 1, it is not converted. 
-// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_multibatch +func.func @dot_general_upstream_srq_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> return %1 : tensor<3x1x2x!quant.uniform> @@ -481,11 +456,11 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatc // ----- -// Test that when the weight tensor for `stablehlo.dot_general` is per-axis +// Tests that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has more than one contracting dimension, it is not converted. 
-// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims -func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims +func.func @dot_general_upstream_srq_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> return %1 : tensor<1x1x!quant.uniform> @@ -497,22 +472,592 @@ func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_mult // ----- -// Test that a simple per-tensor quantized stablehlo.dot_general is properly -// fused with a subsequent requantize (qi32->qi8) op then legalized. -// Supports the following format: (lhs: qi8, rhs: qi8) -> result: qi32 +// ============================================================================ +// The following functions tests example quantization patterns outputted from +// StableHLO Quantizer. These patterns should be legalized early directly +// to fused tflite ops. +// ============================================================================ + +// Tests that a simple per-tensor quantized `stablehlo.dot_general` is properly +// lowered to fused `tfl.fully_connected`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. 
+// * dot_general_fn +// * dot_general_with_relu_fn +// * dot_general_with_relu6_fn + +func.func @dot_general_srq(%arg0: tensor<1x1024x!quant.uniform>) -> (tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %2 : tensor<1x3x!quant.uniform> +} +// CHECK-LABEL: dot_general_srq +// CHECK-SAME: (%[[ARG_1:.+]]: tensor<1x1024x!quant.uniform) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: stablehlo.dot_general +// CHECK: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32, 2.000000e+00>> +// CHECK: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: tfl.batch_matmul +// CHECK: return %[[FULLY_CONNECTED]] -// CHECK-LABEL: dot_general_full_integer -// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x1024x!quant.uniform - func.func @dot_general_full_integer(%arg0: tensor<1x1024x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) { - %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> - %1 = 
stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> - %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> - %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %3 : tensor<1x3xf32> +// ----- + +// Tests that a fused per-tensor quantized `stablehlo.dot_general` is properly +// lowered to fused `tfl.fully_connected`. +// TODO: b/309896242 - Add more support for dynamic bias fusion cases. + +func.func @dot_general_with_bias_same_shape_srq(%arg0: tensor<1x1024x!quant.uniform>) -> (tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.constant() {value = dense<2> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.add %2, %1 : tensor<1x3x!quant.uniform> + %4 = stablehlo.uniform_quantize %3 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %4 : tensor<1x3x!quant.uniform> +} +// CHECK-LABEL: dot_general_with_bias_same_shape +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform:f32, 2.000000e+00>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<2> : tensor<1x3xi32>} : () -> tensor<3x!quant.uniform> +// CHECK: %[[FULLY_CONNECTED:.+]] = "tfl.fully_connected"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = 
false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform:f32, 2.000000e+00>>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[FULLY_CONNECTED]] + +// ----- + +// Tests that a simple per-channel quantized `stablehlo.convolution` is properly +// lowered to fused `tfl.conv_2d`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. +// * conv_fn +// * conv_with_relu_fn +// * conv_with_relu6_fn + +func.func @conv_srq(%arg0: tensor<1x5x5x2x!quant.uniform>) -> (tensor<1x6x6x4x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<2x2x2x4xi8>} : () -> tensor<2x2x2x4x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2x!quant.uniform>, tensor<2x2x2x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x6x6x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + return %2 : tensor<1x6x6x4x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>} : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_2:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<0> : tensor<4xi32>} : () -> 
tensor<4x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[QCONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_1]], %[[QCONST_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK: return %[[CONV_2D]] + +func.func @conv_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x3x2xi8>} : () -> tensor<3x3x3x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x32x32x3x!quant.uniform>, tensor<3x3x3x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x32x32x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + return %2 : tensor<1x32x32x2x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_same_padding_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor 
= 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK: return %[[CONV_2D]] : tensor<1x32x32x2x!quant.uniform> + +// ----- + +// Tests that a fused per-channel quantized `stablehlo.convolution` is properly +// lowered to fused `tfl.conv_2d`. +// This case covers for the following quantization patterns because +// activation clipping ranges take affect in scale and zp of the final +// `stablehlo.uniform_quantize`. See more details in b/319168201. +// * conv_with_bias_fn +// * conv_with_bias_and_relu_fn +// * conv_with_bias_and_relu6_fn + +func.func @conv_with_bias_and_relu_srq(%arg0: tensor<1x5x5x2x!quant.uniform>) -> (tensor<1x6x6x4x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<5> : tensor<1x1x1x4xi32>} : () -> tensor<1x1x1x4x!quant.uniform> + %1 = stablehlo.constant() {value = dense<3> : tensor<2x2x2x4xi8>} : () -> tensor<2x2x2x4x!quant.uniform> + %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2, 3] : (tensor<1x1x1x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %3 = stablehlo.convolution(%arg0, %1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x5x5x2x!quant.uniform>, tensor<2x2x2x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + %4 = stablehlo.add %3, %2 : tensor<1x6x6x4x!quant.uniform> + %5 = stablehlo.uniform_quantize %4 : (tensor<1x6x6x4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> + return %5 : tensor<1x6x6x4x!quant.uniform> } +// CHECK-LABEL: func.func @conv_with_bias_and_relu_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x5x5x2x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK-DAG: %[[CONST_0:.+]] = "tfl.pseudo_const"() {value = dense<{{\[\[0, 0\], \[1, 1\], \[1, 1\], \[0, 
0\]\]}}> : tensor<4x2xi32>} : () -> tensor<4x2xi32> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<4x2x2x2xi8>} : () -> tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_2:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<4x!quant.uniform>, value = dense<5> : tensor<1x1x1x4xi32>} : () -> tensor<4x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.pad"(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x5x5x2x!quant.uniform>, tensor<4x2xi32>) -> tensor<1x7x7x2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[PAD]], %[[QCONST_1]], %[[QCONST_2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x2x!quant.uniform>, tensor<4x2x2x2x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00,3.000000e+00,3.000000e+00}>>, tensor<4x!quant.uniform>) -> tensor<1x6x6x4x!quant.uniform> +// CHECK: return %[[CONV_2D]] + +func.func @conv_with_bias_same_padding_srq(%arg0: tensor<1x32x32x3x!quant.uniform>) -> (tensor<1x32x32x2x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x3x2xi8>} : () -> tensor<3x3x3x2x!quant.uniform> + %1 = stablehlo.constant() {value = dense<5> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform> + %2 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x32x32x3x!quant.uniform>, tensor<3x3x3x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + %4 = stablehlo.add %3, %2 : tensor<1x32x32x2x!quant.uniform> + %5 = stablehlo.uniform_quantize %4 : 
(tensor<1x32x32x2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> + return %5 : tensor<1x32x32x2x!quant.uniform> +} +// CHECK-LABEL: func.func @conv_with_bias_same_padding_srq +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x32x32x3x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK-DAG: %[[QCONST_0:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, value = dense<3> : tensor<2x3x3x3xi8>} : () -> tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>> +// CHECK-DAG: %[[QCONST_1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<5> : tensor<1x1x1x2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV_2D:.+]] = "tfl.conv_2d"(%[[ARG_0]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x3x!quant.uniform>, tensor<2x3x3x3x!quant.uniform:f32:0, {3.000000e+00,3.000000e+00}>>, tensor<2x!quant.uniform>) -> tensor<1x32x32x2x!quant.uniform> +// CHECK: return %[[CONV_2D]] -// CHECK-NOT: stablehlo.dot_general -// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform> -// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform:f32, 2.000000e+00>>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform:f32, 2.000000e+00>> -// CHECK: "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform>, tensor<3x!quant.uniform:f32, 2.000000e+00>>) -> tensor<1x3x!quant.uniform> -// CHECK-NOT: tfl.batch_matmul +// ----- + +// Tests that a quantized stablehlo.transpose is converted to tfl.transpose. 
+ +func.func @transpose( + %arg0: tensor<2x3x4x!quant.uniform> + ) -> tensor<4x3x2x!quant.uniform> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4x!quant.uniform>) -> tensor<4x3x2x!quant.uniform> + return %0 : tensor<4x3x2x!quant.uniform> +} +// CHECK-LABEL: transpose +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CST:.+]] = arith.constant dense<[2, 1, 0]> : tensor<3xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<3xi32>) -> tensor<4x3x2x!quant.uniform> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// Tests that a float stablehlo.transpose is not converted to tfl.transpose. + +func.func @float_transpose(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { + %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4xf32>) -> tensor<4x3x2xf32> + return %0 : tensor<4x3x2xf32> +} +// CHECK-LABEL: float_transpose +// CHECK-NOT: tfl.transpose +// CHECK: stablehlo.transpose + +// ----- + +// Tests that a quantized stablehlo.reshape is converted to tfl.reshape. + +func.func @reshape( + %arg0: tensor<2x3x4x!quant.uniform> + ) -> tensor<6x4x!quant.uniform> { + %0 = stablehlo.reshape %arg0 : (tensor<2x3x4x!quant.uniform>) -> tensor<6x4x!quant.uniform> + return %0 : tensor<6x4x!quant.uniform> +} +// CHECK-LABEL: reshape +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> +// CHECK-NOT: stablehlo.reshape +// CHECK: %[[CST:.+]] = arith.constant dense<[6, 4]> : tensor<2xi32> +// CHECK: %[[RESHAPE:.+]] = "tfl.reshape"(%[[ARG0]], %[[CST]]) : (tensor<2x3x4x!quant.uniform>, tensor<2xi32>) -> tensor<6x4x!quant.uniform> +// CHECK: return %[[RESHAPE]] + +// ----- + +// Tests that a float stablehlo.reshape is not converted to tfl.reshape. 
+ +func.func @float_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { + %0 = stablehlo.reshape %arg0 : (tensor<2x3x4xf32>) -> tensor<6x4xf32> + return %0 : tensor<6x4xf32> +} +// CHECK-LABEL: float_reshape +// CHECK-NOT: tfl.reshape +// CHECK: stablehlo.reshape + +// ----- + +// Tests that a quantized stablehlo.select is converted to tfl.select_v2. + +func.func @select( + %arg0: tensor<1x3xi1>, + %arg1: tensor<1x3x!quant.uniform>, + %arg2: tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> { + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : ( + tensor<1x3xi1>, + tensor<1x3x!quant.uniform>, + tensor<1x3x!quant.uniform> + ) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} +// CHECK-LABEL: select +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3xi1>, %[[ARG1:.+]]: tensor<1x3x!quant.uniform>, %[[ARG2:.+]]: tensor<1x3x!quant.uniform> +// CHECK-NOT: stablehlo.select +// CHECK: %[[SELECT:.+]] = "tfl.select_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<1x3xi1>, tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[SELECT]] + +// ----- + +// Tests that a float stablehlo.select is not converted to tfl.select_v2. + +func.func @float_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} +// CHECK-LABEL: float_select +// CHECK-NOT: tfl.select_v2 +// CHECK: stablehlo.select + +// ----- + +// Tests that a quantized stablehlo.concatenate is converted to tfl.concatenation. 
+ +func.func @concatenate( + %arg0: tensor<3x2x!quant.uniform>, + %arg1: tensor<1x2x!quant.uniform> + ) -> tensor<4x2x!quant.uniform> { + %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : ( + tensor<3x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<4x2x!quant.uniform> + return %0 : tensor<4x2x!quant.uniform> +} +// CHECK-LABEL: concatenate +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x!quant.uniform>, %[[ARG1:.+]]: tensor<1x2x!quant.uniform> +// CHECK-NOT: stablehlo.concatenate +// CHECK: %[[CONCAT:.+]] = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> +// CHECK: return %[[CONCAT]] + +// ----- + +// Tests that a float stablehlo.concatenate is not converted to tfl.concatenation. + +func.func @float_concatenate(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { + %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> + return %0 : tensor<4x2xf32> +} +// CHECK-LABEL: float_concatenate +// CHECK-NOT: tfl.concatenation +// CHECK: stablehlo.concatenate + +// ----- + +// Tests that a quantized stablehlo.pad without interior padding is converted to +// tfl.padv2. 
+ +func.func @pad_without_interior_padding( + %arg0: tensor<2x3x!quant.uniform>, + %arg1: tensor> + ) -> tensor<4x5x!quant.uniform> { + %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [0, 0] : ( + tensor<2x3x!quant.uniform>, + tensor> + ) -> tensor<4x5x!quant.uniform> + return %0 : tensor<4x5x!quant.uniform> +} +// CHECK-LABEL: pad_without_interior_padding +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.+]]: tensor> +// CHECK: %[[PADDING:.+]] = arith.constant +// CHECK{LITERAL}: dense<[[0, 2], [1, 1]]> : tensor<2x2xi32> +// CHECK: %[[PAD:.+]] = "tfl.padv2"(%[[ARG0]], %[[PADDING]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<4x5x!quant.uniform> +// CHECK: return %[[PAD]] + +// ----- + +// Tests that a quantized stablehlo.pad with interior padding is converted to +// tfl.dilate and tfl.padv2. + +func.func @pad_with_interior_padding( + %arg0: tensor<2x3x!quant.uniform>, + %arg1: tensor> + ) -> tensor<5x9x!quant.uniform> { + %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [1, 2] : ( + tensor<2x3x!quant.uniform>, + tensor> + ) -> tensor<5x9x!quant.uniform> + return %0 : tensor<5x9x!quant.uniform> +} +// CHECK-LABEL: pad_with_interior_padding +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.+]]: tensor> +// CHECK: %[[PADDING:.+]] = arith.constant +// CHECK{LITERAL}: dense<[[0, 2], [1, 1]]> : tensor<2x2xi32> +// CHECK: %[[INTERIOR:.+]] = arith.constant +// CHECK{LITERAL}: dense<[1, 2]> : tensor<2xi32> +// CHECK: %[[DILATE:.+]] = "tfl.dilate"(%[[ARG0]], %[[INTERIOR]], %[[ARG1]]) : (tensor<2x3x!quant.uniform>, tensor<2xi32>, tensor>) -> tensor<3x7x!quant.uniform> +// CHECK: %[[PAD:.+]] = "tfl.padv2"(%[[DILATE]], %[[PADDING]], %[[ARG1]]) : (tensor<3x7x!quant.uniform>, tensor<2x2xi32>, tensor>) -> tensor<5x9x!quant.uniform> +// CHECK: return %[[PAD]] + +// ----- + +// Tests that a float stablehlo.pad is not converted to 
tfl.padv2. + +func.func @float_pad(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { + %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [0, 0] : (tensor<2x3xf32>, tensor) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} +// CHECK-LABEL: float_pad +// CHECK-NOT: tfl.padv2 +// CHECK: stablehlo.pad + +// ----- + +// Tests that a quantized stablehlo.slice is converted to tfl.slice when stride +// is 1. + +func.func @slice( + %arg0: tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : ( + tensor<3x4x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} +// CHECK-LABEL: slice +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x4x!quant.uniform> +// CHECK-DAG: %[[START:.+]] = arith.constant dense<{{\[1, 2\]}}> : tensor<2xi32> +// CHECK-DAG: %[[SIZE:.+]] = arith.constant dense<2> : tensor<2xi32> +// CHECK: %[[SLICE:.+]] = "tfl.slice"(%[[ARG0]], %[[START]], %[[SIZE]]) : (tensor<3x4x!quant.uniform>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[SLICE]] + +// ----- + +// Tests that a quantized stablehlo.slice is converted to tfl.strided_slice when +// stride is not 1. 
+ +func.func @strided_slice( + %arg0: tensor<3x6x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : ( + tensor<3x6x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} +// CHECK-LABEL: strided_slice +// CHECK-SAME: %[[ARG0:.+]]: tensor<3x6x!quant.uniform> +// CHECK: %[[START:.+]] = arith.constant +// CHECK{LITERAL}: dense<[0, 2]> : tensor<2xi32> +// CHECK: %[[SIZE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[3, 4]> : tensor<2xi32> +// CHECK: %[[STRIDE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[2, 3]> : tensor<2xi32> +// CHECK: %[[SLICE:.+]] = "tfl.strided_slice"(%[[ARG0]], %[[START]], %[[SIZE]], %[[STRIDE]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<3x6x!quant.uniform>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[SLICE]] + +// ----- + +// Tests that a float stablehlo.slice is not converted to tfl.slice. + +func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x4xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +// CHECK-LABEL: float_slice +// CHECK-NOT: tfl.slice +// CHECK-NOT: tfl.strided_slice +// CHECK: stablehlo.slice + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.broadcast_to. 
+ +func.func @broadcast_in_dim( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<3x2x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + return %0 : tensor<3x2x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK: %[[SHAPE:.+]] = arith.constant +// CHECK{LITERAL}: dense<[3, 2]> : tensor<2xi32> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[ARG0]], %[[SHAPE]]) : (tensor<1x2x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.transpose and tfl.broadcast_to when broadcast_dimensions is not in +// ascending order. + +func.func @broadcast_in_dim_with_transpose( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<2x3x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<2x3x!quant.uniform> + return %0 : tensor<2x3x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_transpose +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK: %[[BROADCAST_DIM:.+]] = arith.constant +// CHECK{LITERAL}: dense<[2, 3]> : tensor<2xi32> +// CHECK: %[[PERM:.+]] = arith.constant +// CHECK{LITERAL}: dense<[1, 0]> : tensor<2xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[PERM]]) : (tensor<1x2x!quant.uniform>, tensor<2xi32>) -> tensor<2x1x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[TRANSPOSE]], %[[BROADCAST_DIM]]) : (tensor<2x1x!quant.uniform>, tensor<2xi32>) -> tensor<2x3x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.expand_dims and tfl.broadcast_to when input rank is smaller than output +// rank. 
+ +func.func @broadcast_in_dim_with_expand( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<3x2x1x1x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2x!quant.uniform>) -> tensor<3x2x1x1x!quant.uniform> + return %0 : tensor<3x2x1x1x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_expand +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x!quant.uniform> +// CHECK-DAG: %[[BROADCAST_DIM:.+]] = arith.constant dense<{{\[3, 2, 1, 1\]}}> : tensor<4xi32> +// CHECK-DAG: %[[EXPAND_DIM1:.+]] = arith.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[EXPAND_DIM0:.+]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[EXPAND0:.+]] = "tfl.expand_dims"(%[[ARG0]], %[[EXPAND_DIM0]]) : (tensor<1x2x!quant.uniform>, tensor<1xi32>) -> tensor<1x2x1x!quant.uniform> +// CHECK: %[[EXPAND1:.+]] = "tfl.expand_dims"(%[[EXPAND0]], %[[EXPAND_DIM1]]) : (tensor<1x2x1x!quant.uniform>, tensor<1xi32>) -> tensor<1x2x1x1x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[EXPAND1]], %[[BROADCAST_DIM]]) : (tensor<1x2x1x1x!quant.uniform>, tensor<4xi32>) -> tensor<3x2x1x1x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a quantized stablehlo.broadcast_in_dim is converted to +// tfl.transpose, tfl.expand_dims and tfl.broadcast_to when broadcast_dimensions +// is not in ascending order and input rank is smaller than output rank. 
+ +func.func @broadcast_in_dim_with_transpose_and_expand( + %arg0: tensor<2x3x4x!quant.uniform> + ) -> tensor<3x2x1x1x4x!quant.uniform> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<2x3x4x!quant.uniform>) -> tensor<3x2x1x1x4x!quant.uniform> + return %0 : tensor<3x2x1x1x4x!quant.uniform> +} +// CHECK-LABEL: broadcast_in_dim_with_transpose_and_expand +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x!quant.uniform> +// CHECK-DAG: %[[BROADCAST_DIM:.+]] = arith.constant dense<{{\[3, 2, 1, 1, 4\]}}> : tensor<5xi32> +// CHECK-DAG: %[[EXPAND_DIM1:.+]] = arith.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[EXPAND_DIM0:.+]] = arith.constant dense<2> : tensor<1xi32> +// CHECK-DAG: %[[PERM:.+]] = arith.constant dense<{{\[1, 0, 2\]}}> : tensor<3xi32> +// CHECK: %[[TRANSPOSE:.+]] = "tfl.transpose"(%[[ARG0]], %[[PERM]]) : (tensor<2x3x4x!quant.uniform>, tensor<3xi32>) -> tensor<3x2x4x!quant.uniform> +// CHECK: %[[EXPAND0:.+]] = "tfl.expand_dims"(%[[TRANSPOSE]], %[[EXPAND_DIM0]]) : (tensor<3x2x4x!quant.uniform>, tensor<1xi32>) -> tensor<3x2x1x4x!quant.uniform> +// CHECK: %[[EXPAND1:.+]] = "tfl.expand_dims"(%[[EXPAND0]], %[[EXPAND_DIM1]]) : (tensor<3x2x1x4x!quant.uniform>, tensor<1xi32>) -> tensor<3x2x1x1x4x!quant.uniform> +// CHECK: %[[BROADCAST:.+]] = "tfl.broadcast_to"(%[[EXPAND1]], %[[BROADCAST_DIM]]) : (tensor<3x2x1x1x4x!quant.uniform>, tensor<5xi32>) -> tensor<3x2x1x1x4x!quant.uniform> +// CHECK: return %[[BROADCAST]] + +// ----- + +// Tests that a float stablehlo.broadcast_in_dim is not converted to tfl.broadcast_to. 
+ +func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<1x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// CHECK-LABEL: float_broadcast_in_dim +// CHECK-NOT: tfl.broadcast_to +// CHECK-NOT: tfl.transpose +// CHECK-NOT: tfl.expand_dims +// CHECK: stablehlo.broadcast_in_dim + +// ----- + +// Test that a quantized stablehlo.reduce_window with max is converted to +// tfl.max_pool_2d. + +func.func @reduce_window_with_max( + %arg0: tensor<2x9x10x3x!quant.uniform>, + %arg1: tensor> +) -> tensor<2x4x3x3x!quant.uniform> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.maximum %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<2x9x10x3x!quant.uniform>, tensor>) -> tensor<2x4x3x3x!quant.uniform> + return %0 : tensor<2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_with_max +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x9x10x3x!quant.uniform> +// CHECK-SAME: %[[ARG1:.*]]: tensor> +// CHECK: %[[MAX_POOL:.*]] = "tfl.max_pool_2d"(%[[ARG0]]) +// CHECK-SAME: {filter_height = 3 : i32, filter_width = 4 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 3 : i32} +// CHECK-SAME: (tensor<2x9x10x3x!quant.uniform>) -> tensor<2x4x3x3x!quant.uniform> +// CHECK: return %[[MAX_POOL]] + +// ----- + +// Test that a quantized stablehlo.reduce_window with max whose rank is not 4 +// is not converted to tfl.max_pool_2d. 
+ +func.func @reduce_window_not_4d( + %arg0: tensor<3x2x9x10x3x!quant.uniform>, + %arg1: tensor> +) -> tensor<3x2x4x3x3x!quant.uniform> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.maximum %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<3x2x9x10x3x!quant.uniform>, tensor>) -> tensor<3x2x4x3x3x!quant.uniform> + return %0 : tensor<3x2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_not_4d +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d + +// ----- + +// Test that a quantized stablehlo.reduce_window with max that takes multiple +// inputs is not converted to tfl.max_pool_2d. + +func.func @reduce_window_not_binary( + %arg0: tensor<3x2x9x10x3x!quant.uniform>, + %arg1: tensor<3x2x9x10x3x!quant.uniform>, + %arg2: tensor>, + %arg3: tensor> +) -> tensor<3x2x4x3x3x!quant.uniform> { + %0, %1 = "stablehlo.reduce_window"(%arg0, %arg1, %arg2, %arg3) ({ + ^bb0(%arg4: tensor>, %arg5: tensor>, %arg6: tensor>, %arg7: tensor>): + %2 = stablehlo.maximum %arg4, %arg5 : tensor> + %3 = stablehlo.maximum %arg6, %arg7 : tensor> + stablehlo.return %2, %3 : tensor>, tensor> + }) {window_dimensions = array, window_strides = array} : (tensor<3x2x9x10x3x!quant.uniform>, tensor<3x2x9x10x3x!quant.uniform>, tensor>, tensor>) -> (tensor<3x2x4x3x3x!quant.uniform>, tensor<3x2x4x3x3x!quant.uniform>) + return %0 : tensor<3x2x4x3x3x!quant.uniform> +} + +// CHECK-LABEL: reduce_window_not_binary +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d + +// ----- + +// Test that a float stablehlo.reduce_window with max is not converted to +// tfl.max_pool_2d. 
+ +func.func @float_reduce_window_with_max( + %arg0: tensor<2x9x10x3xf32>, + %arg1: tensor +) -> tensor<2x4x3x3xf32> { + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.maximum %arg2, %arg3 : tensor + stablehlo.return %1 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<2x9x10x3xf32>, tensor) -> tensor<2x4x3x3xf32> + return %0 : tensor<2x4x3x3xf32> +} + +// CHECK-LABEL: float_reduce_window_with_max +// CHECK: stablehlo.reduce_window +// CHECK-NOT: tfl.max_pool_2d diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index 4ddfdb0d33ff05..15481b9a0a1ad2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -38,7 +38,7 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #define DEBUG_TYPE "stablehlo-compose-uniform-quantized-type" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 26d45005c835a8..000f88639240f3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -482,10 +482,15 @@ class Convert1DConvOp : public OpConversionPattern { const int64_t input_channels = conv_op.getLhs().getType().cast().getDimSize( input_feature_dimension); + const int kernel_input_feature_dimension = + dnums.getKernelInputFeatureDimension(); + const int kernel_input_channels = + conv_op.getRhs().getType().cast().getDimSize( + kernel_input_feature_dimension); const int64_t feature_group_count = conv_op.getFeatureGroupCount(); - if (feature_group_count != 1 && feature_group_count != input_channels) - return rewriter.notifyMatchFailure(conv_op, - "Group convolution is not supported,"); + if (feature_group_count != input_channels / kernel_input_channels || + input_channels % kernel_input_channels != 0) + return failure(); // // Transpose and reshape the input and kernel @@ -498,6 +503,7 @@ class Convert1DConvOp : public OpConversionPattern { image_2d_shape.push_back(1); auto image_2d_type = RankedTensorType::get(image_2d_shape, image_type.getElementType()); + auto loc = conv_op.getLoc(); auto image_2d_op = rewriter.create( conv_op.getLoc(), image_2d_type, conv_op.getLhs()); @@ -509,8 +515,8 @@ class Convert1DConvOp : public OpConversionPattern { auto image_permutation_and_shape 
= GetPermutationAndTransposedShape( image_permutation, image_2d_type, rewriter); auto transposed_image_2d_op = rewriter.create( - conv_op.getLoc(), image_permutation_and_shape.shape, - image_2d_op->getResult(0), image_permutation_and_shape.permutation); + loc, image_permutation_and_shape.shape, image_2d_op->getResult(0), + image_permutation_and_shape.permutation); // Reshape kernel to add a new spatial dimension. auto kernel_type = conv_op.getRhs().getType().cast(); @@ -521,8 +527,8 @@ class Convert1DConvOp : public OpConversionPattern { kernel_2d_shape.push_back(1); auto kernel_2d_type = RankedTensorType::get(kernel_2d_shape, kernel_type.getElementType()); - auto kernel_2d_op = rewriter.create( - conv_op.getLoc(), kernel_2d_type, conv_op.getRhs()); + auto kernel_2d_op = + rewriter.create(loc, kernel_2d_type, conv_op.getRhs()); // Transpose kernel to get it into WHIO form (where H is the added dim). SmallVector kernel_permutation = { @@ -533,8 +539,8 @@ class Convert1DConvOp : public OpConversionPattern { auto kernel_permutation_and_shape = GetPermutationAndTransposedShape( kernel_permutation, kernel_2d_type, rewriter); auto transposed_kernel_2d_op = rewriter.create( - conv_op.getLoc(), kernel_permutation_and_shape.shape, - kernel_2d_op->getResult(0), kernel_permutation_and_shape.permutation); + loc, kernel_permutation_and_shape.shape, kernel_2d_op->getResult(0), + kernel_permutation_and_shape.permutation); // // Create 2d equivalents for 1d convolution attributes. 
@@ -624,11 +630,11 @@ class Convert1DConvOp : public OpConversionPattern { .shape; auto conv2d_op = rewriter.create( - conv_op.getLoc(), transposed_output_2d_shape, - transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(), - window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d, - window_reversal_2d, dnums_2d, conv_op.getFeatureGroupCount(), - conv_op.getBatchGroupCount(), conv_op.getPrecisionConfigAttr()); + loc, transposed_output_2d_shape, transposed_image_2d_op.getResult(), + transposed_kernel_2d_op.getResult(), window_strides_2d, padding_2d, + lhs_dilation_2d, rhs_dilation_2d, window_reversal_2d, dnums_2d, + conv_op.getFeatureGroupCount(), conv_op.getBatchGroupCount(), + conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); auto conv2d_output_type = conv2d_output.getType().cast(); @@ -642,7 +648,7 @@ class Convert1DConvOp : public OpConversionPattern { auto output_permutation_and_shape = GetInversePermutationAndShape( output_permutation, conv2d_output_type, rewriter); auto transposed_output_2d_op = rewriter.create( - conv_op.getLoc(), output_permutation_and_shape.shape, conv2d_output, + loc, output_permutation_and_shape.shape, conv2d_output, output_permutation_and_shape.permutation); // Drop the trailing spatial dimension from the output. @@ -734,13 +740,13 @@ class ConvertNonTrivialConvOp // Mirror the filter in the spatial dimensions. 
mlir::Value reverse_filter_in = conv_op.getRhs(); - // If the kernel is with format [0,1,i,o] we transpose it to [0,1,o,i] - // as the TF->TFL pass anticipates this and the kernel format information - // will be lost once we legalize to TF - if (isKernelFormatHWIO(dnums)) { + // If the kernel is with format anythoing other than HWOI, we + // transpose it to [0,1,o,i] as the TF->TFL pass anticipates this and the + // kernel format information will be lost once we legalize to TF + if (!isKernelFormatHWOI(dnums)) { SmallVector permutation; - for (int64_t dim : dnums.getInputSpatialDimensions()) { - permutation.push_back(dim - 1); + for (int64_t dim : dnums.getKernelSpatialDimensions()) { + permutation.push_back(dim); } permutation.push_back(dnums.getKernelOutputFeatureDimension()); permutation.push_back(dnums.getKernelInputFeatureDimension()); @@ -753,10 +759,11 @@ class ConvertNonTrivialConvOp permutation)); reverse_filter_in = filter_transposed; } - mhlo::ReverseOp filter; - filter = rewriter.create( - conv_op.getLoc(), reverse_filter_in, - rewriter.getI64TensorAttr(dnums.getKernelSpatialDimensions())); + + // Lets hard-code the reverse indexes to be {0, 1} as the expectation is + // that the kernel is always in HWOI format, with the above code. 
+ mhlo::ReverseOp filter = rewriter.create( + conv_op.getLoc(), reverse_filter_in, rewriter.getI64TensorAttr({0, 1})); // if output is not in [b, 0, 1, f] format, insert transpose to go back if (dnums.getOutputBatchDimension() != 0 || @@ -914,12 +921,6 @@ class ConvertNonTrivialConvOp return true; } - bool isKernelFormatHWIO(mhlo::ConvDimensionNumbersAttr dnums) const { - int64_t num_spatial_dims = dnums.getKernelSpatialDimensions().size(); - return dnums.getKernelInputFeatureDimension() == num_spatial_dims && - dnums.getKernelOutputFeatureDimension() == num_spatial_dims + 1; - } - bool isKernelFormatHWOI(mhlo::ConvDimensionNumbersAttr dnums) const { int64_t num_spatial_dims = dnums.getKernelSpatialDimensions().size(); return dnums.getKernelInputFeatureDimension() == num_spatial_dims + 1 && @@ -969,17 +970,6 @@ class ConvertNonTrivialConvOp "doesn't support negative pads"); } - // Checks kernel dimensions. - if (!isKernelFormatHWIO(dnums) && !isKernelFormatHWOI(dnums)) - return rewriter.notifyMatchFailure( - conv_op, "requires kernel format [0, 1, o, i] or [0, 1, i, o]"); - auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions(); - for (auto p : llvm::enumerate(kernel_spatial_dimensions)) { - if (p.value() != p.index()) - return rewriter.notifyMatchFailure( - conv_op, "requires kernel format [0, 1, o, i] or [0, 1, i, o]"); - } - return success(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index 1607343ab15879..fe988ba9b20265 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -23,7 +23,7 @@ include "mhlo/IR/hlo_ops.td" // Check if broadcasting is compatible with TF ops. 
def IsLegalNumpyRankedBroadcast : - Constraint, + Constraint{})">, "broadcasting should be compatible with TF ops">; // Return a constant op that carries the shape of the given value. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc index 6e7ebdec6058dc..5f04704d54ef78 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -20,19 +20,26 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo @@ 
-48,12 +55,11 @@ static constexpr char kShardingAttr[] = "mhlo.sharding"; static constexpr char kShardingName[] = "Sharding"; class RemoveCustomCallWithSharding - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - mlir::stablehlo::CustomCallOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const override { // Removes the custom call with sharding op if the operand type is the // same as the result type. if (op->hasAttr(kShardingAttr) && op.getCallTargetName() == kShardingName && @@ -61,15 +67,15 @@ class RemoveCustomCallWithSharding op.getOperands().front().getType() == op.getResults().front().getType()) { rewriter.replaceOp(op, op.getOperands()); - return mlir::success(); + return success(); } - return mlir::failure(); + return failure(); } }; namespace { -bool IsShloMainFuncOp(mlir::func::FuncOp func_op) { +bool IsShloMainFuncOp(func::FuncOp func_op) { if (func_op == nullptr) { return false; } @@ -86,32 +92,55 @@ bool IsShloMainFuncOp(mlir::func::FuncOp func_op) { return true; } +// Returns true if XlaCallModuleOp has the "platform index argument". The +// platform index argument is an extra 0-dimensional i32 tensor argument at +// index 0 when the XlaCallModuleOp contains more than one platform specified at +// the "platform" attribute. +// +// See: +// https://github.com/tensorflow/tensorflow/blob/eba24f41ba9d661d2f58a515921720cf90708cd4/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1376-L1385 +bool ContainsPlatformIndexArg(TF::XlaCallModuleOp xla_call_module_op) { + return xla_call_module_op.getPlatforms().size() > 1; +} + +// Removes the platform index argument from the function. It is equivalent to +// removing the first argument from `func_op` (see the comments at +// `ContainsPlatformIndexArg`). 
This function assumes that `func_op` is a valid +// function deserialized from XlaCallModule op. +void RemovePlatformIndexArg(MLIRContext *ctx, func::FuncOp func_op) { + // If there are multiple platforms, the first argument is reserved for + // passing the platform index. + FunctionType function_type = func_op.getFunctionType(); + ArrayRef new_input_types = + function_type.getInputs().take_back(func_op.getNumArguments() - 1); + func_op.setFunctionType( + FunctionType::get(ctx, new_input_types, function_type.getResults())); + func_op.getBody().eraseArgument(0); +} + } // namespace -class ConvertTFXlaCallModuleOp - : public mlir::OpRewritePattern { +class ConvertTFXlaCallModuleOp : public OpRewritePattern { public: explicit ConvertTFXlaCallModuleOp(MLIRContext *context, ModuleOp module_op) - : OpRewritePattern(context), - module_op_(module_op) {} - using OpRewritePattern::OpRewritePattern; + : OpRewritePattern(context), module_op_(module_op) {} + using OpRewritePattern::OpRewritePattern; private: ModuleOp module_op_; - mlir::LogicalResult matchAndRewrite( - mlir::TF::XlaCallModuleOp op, PatternRewriter &rewriter) const override { - mlir::OwningOpRef stablehlo_module_op = - mlir::stablehlo::deserializePortableArtifact(op.getModuleAttr(), - getContext()); + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter &rewriter) const override { + OwningOpRef stablehlo_module_op = + stablehlo::deserializePortableArtifact(op.getModuleAttr(), + getContext()); if (stablehlo_module_op.get() == nullptr) { - return mlir::failure(); + return failure(); } SymbolTable parent_module_symbol_table(module_op_); SymbolTable stablehlo_module_symbol_table(stablehlo_module_op.get()); { - auto main_func_op = - stablehlo_module_symbol_table.lookup( - kStablehloModuleDefaultEntryFuncName); + auto main_func_op = stablehlo_module_symbol_table.lookup( + kStablehloModuleDefaultEntryFuncName); // TODO(b/291988976): move enforcement of this variable outside of this // rewrite pattern 
such that it's only checked once. Currently, this // approach results in duplicate error messages as this pattern executes @@ -126,25 +155,23 @@ class ConvertTFXlaCallModuleOp return rewriter.notifyMatchFailure(op, error_msg); } } - mlir::Builder stablehlo_builder(stablehlo_module_op.get().getContext()); + Builder stablehlo_builder(stablehlo_module_op.get().getContext()); // Rename XlaCallModuleOp's functions to avoid naming conflicts. - for (auto func_op : - stablehlo_module_op.get().getOps()) { + for (auto func_op : stablehlo_module_op.get().getOps()) { const std::string new_func_name = CreateNewFuncName(func_op.getSymName(), parent_module_symbol_table); if (failed(stablehlo_module_symbol_table.replaceAllSymbolUses( func_op, stablehlo_builder.getStringAttr(new_func_name), stablehlo_module_op.get()))) { - return mlir::failure(); + return failure(); } - mlir::SymbolTable::setSymbolName(func_op, new_func_name); + SymbolTable::setSymbolName(func_op, new_func_name); } // Move all functions from XlaCallModuleOp's stablehlo module, to parent // module. Also marks the stablehlo module entry function as private. - mlir::func::FuncOp main_fn; - for (auto func_op : - stablehlo_module_op.get().getOps()) { - mlir::func::FuncOp cloned_func_op = func_op.clone(); + func::FuncOp main_fn; + for (auto func_op : stablehlo_module_op.get().getOps()) { + func::FuncOp cloned_func_op = func_op.clone(); if (IsShloMainFuncOp(cloned_func_op)) { main_fn = cloned_func_op; } @@ -153,11 +180,20 @@ class ConvertTFXlaCallModuleOp parent_module_symbol_table.insert(cloned_func_op); } + // When the `XlaCallModuleOp`'s callee accepts a platform index argument, + // remove it. This is because when converted to `CallOp` there will be a + // mismatch btw. the number of arguments passed and number of parameters + // accepted (the platform index argument is an extra argument that is not + // expressed by the operands of XlaCallModuleOp). 
+ if (ContainsPlatformIndexArg(op)) { + RemovePlatformIndexArg(getContext(), main_fn); + } + // The stablehlo module main function's input tensor types might be // different from the XlaCallModuleOp's input tensor types. For example, // The XlaCallModuleOp's input is tensor<*xf32> while the function's // argument type is tensor<1x2f32>. - llvm::SmallVector casted_operands; + SmallVector casted_operands; casted_operands.reserve(main_fn.getNumArguments()); for (const auto &operand_and_type : zip(op.getOperands(), main_fn.getFunctionType().getInputs())) { @@ -176,7 +212,7 @@ class ConvertTFXlaCallModuleOp casted_operands); rewriter.replaceOp(op, call->getResults()); - return mlir::success(); + return success(); } // Creates a new function name to avoid collision. The naming scheme is @@ -206,9 +242,9 @@ class TFXlaCallModuleOpToStablehloPass StringRef getDescription() const final { return "Legalize TF_XlaCallModule Op to stablehlo"; } - void getDependentDialects(::mlir::DialectRegistry ®istry) const override { - registry.insert(); + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); } void runOnOperation() override { @@ -222,7 +258,7 @@ class TFXlaCallModuleOpToStablehloPass } }; -std::unique_ptr> +std::unique_ptr> CreateLegalizeTFXlaCallModuleToStablehloPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h index 211019b70524f5..a91584b0ff050e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -47,7 +47,7 @@ CreateComposeUniformQuantizedTypePass(); // quantized typed tensors and converts them to equivalent ops in the TFLite // dialect. std::unique_ptr> -CreateUniformQuantizedStablehloToTflPass(); +CreateUniformQuantizedStableHloToTflPass(); // Create a pass that legalizes MHLO to TF dialect. 
std::unique_ptr> CreateLegalizeHloToTfPass(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index ca49787a715bd5..be4c936602eafb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -58,10 +58,10 @@ def ComposeUniformQuantizedTypePass : Pass<"compose-uniform-quantized-type", "Mo ]; } -def UniformQuantizedStablehloToTflPass +def UniformQuantizedStableHloToTflPass : Pass<"uniform-quantized-stablehlo-to-tfl", "mlir::func::FuncOp"> { let summary = "Converts StableHLO ops using uniform quantized types to equivalent TFL ops."; - let constructor = "mlir::odml::CreateUniformQuantizedStablehloToTflPass()"; + let constructor = "mlir::odml::CreateUniformQuantizedStableHloToTflPass()"; let description = [{ Converts StableHLO ops that accept or return uniform quantized types to equivalent ops in the TFLite dialect. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc index 3bb9eddbfa5021..b120ca89c290d4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc @@ -135,6 +135,16 @@ void StablehloToTflPass::runOnOperation() { continue; } + if (attr.isa<::mlir::DenseBoolArrayAttr>()) { + auto array_attr = attr.dyn_cast(); + auto start = fbb->StartVector(key); + for (auto bool_value : array_attr.asArrayRef()) { + fbb->Add(bool_value); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + continue; + } + if (attr.isa<::mlir::StringAttr>()) { fbb->String(key, attr.dyn_cast().data()); continue; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index 
5e4f79f18ce503..8e22b343d7f3d0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -44,6 +44,51 @@ limitations under the License. namespace mlir { namespace odml { +static bool isDenseI64Array(llvm::StringRef op_name, + llvm::StringRef field_name) { + if (op_name == "stablehlo.broadcast" && field_name == "broadcast_sizes") + return true; + if (op_name == "stablehlo.broadcast_in_dim" && + field_name == "broadcast_dimensions") + return true; + if ((op_name == "stablehlo.convolution" || + op_name == "stablehlo.dynamic_conv") && + (field_name == "window_strides" || field_name == "lhs_dilation" || + field_name == "rhs_dilation")) + return true; + if (op_name == "stablehlo.dynamic_broadcast_in_dim" && + (field_name == "broadcast_dimensions" || + field_name == "known_expanding_dimensions" || + field_name == "known_nonexpanding_dimensions")) + return true; + if ((op_name == "stablehlo.dynamic_slice" || op_name == "stablehlo.gather") && + field_name == "slice_sizes") + return true; + if (op_name == "stablehlo.fft" && field_name == "fft_length") return true; + if ((op_name == "stablehlo.map" || op_name == "stablehlo.reduce" || + op_name == "stablehlo.reverse") && + field_name == "dimensions") + return true; + if (op_name == "stablehlo.pad" && + (field_name == "edge_padding_low" || field_name == "edge_padding_high" || + field_name == "interior_padding")) + return true; + if (op_name == "stablehlo.reduce_window" && + (field_name == "window_dimensions" || field_name == "window_strides" || + field_name == "base_dilations" || field_name == "window_dilations")) + return true; + if (op_name == "stablehlo.select_and_scatter" && + (field_name == "window_dimensions" || field_name == "window_strides")) + return true; + if (op_name == "stablehlo.slice" && + (field_name == "start_indices" || field_name == "limit_indices" || + field_name == "strides")) + return true; + if 
(op_name == "stablehlo.transpose" && field_name == "permutation") + return true; + return false; +} + class TflToStablehloPass : public mlir::PassWrapper> { @@ -90,6 +135,16 @@ class TflToStablehloPass attrs.push_back(named_attr); break; } + case flexbuffers::FBT_VECTOR_BOOL: { + llvm::SmallVector vec; + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + vec.push_back(vector[i].AsBool()); + } + attrs.push_back( + builder->getNamedAttr(key, builder->getDenseBoolArrayAttr(vec))); + break; + } case flexbuffers::FBT_VECTOR_INT: { const auto& vector = value.AsTypedVector(); std::vector vec; @@ -104,11 +159,7 @@ class TflToStablehloPass shape.push_back(vec.size()); } Attribute value; - if (op_name == "stablehlo.broadcast" || - op_name == "stablehlo.dynamic_slice" || - op_name == "stablehlo.fft" || op_name == "stablehlo.pad" || - op_name == "stablehlo.reverse" || op_name == "stablehlo.slice" || - op_name == "stablehlo.transpose") { + if (isDenseI64Array(op_name, key)) { value = builder->getDenseI64ArrayAttr(vec); } else { RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index df3a5f62e8ff59..678c136d5a9571 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -38,7 +38,7 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, // if the input is a call_xla_module, then unwrap the content pm.addPass(mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); - // TODO(b/230572023): Consider improving shape inference for While op instead + // TODO: b/230572023 - Consider improving shape inference for While op instead // of dropping the attribute. This need not be correct for models not trained // on TPU. 
@@ -85,11 +85,18 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, } } -void AddMhloOptimizationPasses(OpPassManager& pm) { +void AddMhloOptimizationPasses(OpPassManager& pm, + const bool enable_stablehlo_quantizer) { // Rewrites some patterns for better performance. pm.addNestedPass(createUnfuseBatchNormPass()); pm.addNestedPass(createFuseConvolutionPass()); pm.addNestedPass(createOptimizePass()); + // Conditionally enable below pass because this causes unfused convolutions + // described in b/293149194. This problem is not replicated in + // StableHLO Quantizer. + if (enable_stablehlo_quantizer) { + pm.addNestedPass(createFoldBroadcastPass()); + } // Rewrites legacy StableHLO ops. pm.addNestedPass(mhlo::createLegalizeEinsumToDotGeneralPass()); @@ -109,8 +116,8 @@ void AddStablehloOptimizationPasses(OpPassManager& pm) { // StableHLO -> MHLO legalization. pm.addPass(mhlo::createStablehloLegalizeToHloPass()); - AddMhloOptimizationPasses(pm); - // TODO(b/293149194) Add `createFoldBroadcastPass` back to + AddMhloOptimizationPasses(pm, /*enable_stablehlo_quantizer=*/false); + // TODO: b/293149194 - Add `createFoldBroadcastPass` back to // `AddMhloOptimizationPasses` pm.addNestedPass(createFoldBroadcastPass()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h index 02d5d527901c9a..dc23a5b30f0b3c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h @@ -36,7 +36,8 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, void AddStablehloOptimizationPasses(OpPassManager& pm); // Adds all the backend-agonstic stableHLO optimization passes -void AddMhloOptimizationPasses(OpPassManager& pm); +void AddMhloOptimizationPasses(OpPassManager& pm, + bool enable_stablehlo_quantizer); } // namespace odml } // namespace mlir diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index dffb51ecd3f800..429c6b820623dd 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -14,17 +14,21 @@ limitations under the License. ==============================================================================*/ #include #include -#include #include +#include #include #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // NOLINT: Required to register quantization dialect. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -37,10 +41,10 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/Base.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" @@ -51,24 +55,38 @@ namespace { // TODO: b/311029361: Add e2e test for verifying this legalization once // StableHLO Quantizer API migration is complete. +using ::mlir::quant::CastI64ArrayToI32; +using ::mlir::quant::CastI64ToI32; +using ::mlir::quant::CreateI32F32UniformQuantizedPerAxisType; +using ::mlir::quant::CreateI32F32UniformQuantizedType; +using ::mlir::quant::CreateI8F32UniformQuantizedPerAxisType; +using ::mlir::quant::CreateI8F32UniformQuantizedType; +using ::mlir::quant::FindUserOfType; +using ::mlir::quant::IsI32F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI8F32UniformQuantizedType; +using ::mlir::quant::IsOpFullyQuantized; +using ::mlir::quant::IsQuantizedTensorType; using ::mlir::quant::IsSupportedByTfliteQuantizeOrDequantizeOps; using ::mlir::quant::QuantizedType; using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; +const char* kPaddingSame = "SAME"; +const char* kPaddingValid = "VALID"; + #define GEN_PASS_DEF_UNIFORMQUANTIZEDSTABLEHLOTOTFLPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" -class UniformQuantizedStablehloToTflPass - : public impl::UniformQuantizedStablehloToTflPassBase< - 
UniformQuantizedStablehloToTflPass> { +class UniformQuantizedStableHloToTflPass + : public impl::UniformQuantizedStableHloToTflPassBase< + UniformQuantizedStableHloToTflPass> { private: void runOnOperation() override; }; +// TODO: b/323645515 - Refactor reference functions. // Bias scales for matmul-like ops should be input scale * filter scale. Here it // is assumed that the input is per-tensor quantized and filter is per-channel // quantized. @@ -94,7 +112,7 @@ double GetBiasScale(const double input_scale, const double filter_scale) { // whereas `tfl.fully_connected` accepts an OI format. TFL::QConstOp CreateTflConstOpForFilter( stablehlo::ConstantOp filter_constant_op, PatternRewriter& rewriter, - bool is_per_axis) { + bool is_per_channel) { const auto filter_values = filter_constant_op.getValue() .cast() .getValues(); @@ -123,35 +141,27 @@ TFL::QConstOp CreateTflConstOpForFilter( Type new_filter_quantized_type; - if (is_per_axis) { + if (is_per_channel) { auto filter_quantized_type = filter_constant_op.getResult() .getType() .cast() .getElementType() .cast(); - - new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( - filter_constant_op.getLoc(), /*flags=*/true, - /*storageType=*/filter_quantized_type.getStorageType(), - /*expressedType=*/filter_quantized_type.getExpressedType(), - /*scales=*/filter_quantized_type.getScales(), - /*zeroPoints=*/filter_quantized_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( + filter_constant_op->getLoc(), *rewriter.getContext(), + filter_quantized_type.getScales(), + filter_quantized_type.getZeroPoints(), + /*quantization_dimension=*/0, /*narrow_range=*/true); } else { auto filter_quantized_type = filter_constant_op.getResult() .getType() .cast() .getElementType() .cast(); - new_filter_quantized_type = UniformQuantizedType::getChecked( - filter_constant_op.getLoc(), 
/*flags=*/true, - /*storageType=*/filter_quantized_type.getStorageType(), - /*expressedType=*/filter_quantized_type.getExpressedType(), - /*scale=*/filter_quantized_type.getScale(), - /*zeroPoint=*/filter_quantized_type.getZeroPoint(), - /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + new_filter_quantized_type = CreateI8F32UniformQuantizedType( + filter_constant_op->getLoc(), *rewriter.getContext(), + filter_quantized_type.getScale(), filter_quantized_type.getZeroPoint(), + /*narrow_range=*/true); } // Required because the quantized dimension is changed from 3 -> 0. @@ -172,17 +182,16 @@ TFL::QConstOp CreateTflConstOpForFilter( // transformation). The quantization scale for the bias is input scale * // filter scale. `filter_const_op` is used to retrieve the filter scales and // the size of the bias constant. -// TODO - b/309896242: Support bias fusion legalization. -TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, - const double input_scale, - TFL::QConstOp filter_const_op, - PatternRewriter& rewriter, - bool is_per_axis) { +// TODO - b/309896242: Support bias fusion legalization and spatial dimension +// check when `stride` is not 1. +TFL::QConstOp CreateTflConstOpForDummyBias( + const Location loc, const double input_scale, TFL::QConstOp filter_const_op, + PatternRewriter& rewriter, bool is_per_channel, MLIRContext& ctx) { const ArrayRef filter_shape = filter_const_op.getResult().getType().getShape(); Type bias_quantized_type; - if (is_per_axis) { + if (is_per_channel) { const auto filter_quantized_element_type = filter_const_op.getResult() .getType() @@ -191,13 +200,11 @@ TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, // The storage type is i32 for bias, which is the precision used for // accumulation. 
- bias_quantized_type = UniformQuantizedPerAxisType::getChecked( - loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), - /*expressedType=*/rewriter.getF32Type(), /*scales=*/ + bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( + loc, ctx, GetBiasScales(input_scale, filter_quantized_element_type.getScales()), - /*zeroPoints=*/filter_quantized_element_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + filter_quantized_element_type.getZeroPoints(), + /*quantization_dimension=*/0); } else { const auto filter_quantized_element_type = filter_const_op.getResult() @@ -207,13 +214,10 @@ TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, // The storage type is i32 for bias, which is the precision used for // accumulation. - bias_quantized_type = UniformQuantizedType::getChecked( - loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), - /*expressedType=*/rewriter.getF32Type(), /*scale=*/ + bias_quantized_type = CreateI32F32UniformQuantizedType( + loc, ctx, GetBiasScale(input_scale, filter_quantized_element_type.getScale()), - /*zeroPoint=*/filter_quantized_element_type.getZeroPoint(), - /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + filter_quantized_element_type.getZeroPoint()); } SmallVector bias_shape = {filter_shape[0]}; @@ -230,6 +234,7 @@ TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, } // stablehlo.uniform_quantize -> tfl.quantize +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
class RewriteUniformQuantizeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -242,7 +247,8 @@ class RewriteUniformQuantizeOp const Type input_element_type = op.getOperand().getType().cast().getElementType(); if (!(input_element_type.isa() || - IsI32F32UniformQuantizedType(input_element_type))) { + IsI32F32UniformQuantizedType(input_element_type) || + IsI32F32UniformQuantizedPerAxisType(input_element_type))) { LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " "float type or int32. Got: " << input_element_type << ".\n"); @@ -322,93 +328,472 @@ class RewriteUniformDequantizeOp } }; -// Rewrites `stablehlo.convolution` -> `tfl.conv_2d` when it accepts uniform -// quantized tensors. +// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or +// `tfl.batch_matmul` when it accepts uniform quantized tensors. // -// Conditions for the conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) +// StableHLO Quantizer output: +// * input: per-tensor qi8 +// * filter: per-tensor qi8 +// * output: per-tensor qi32 +// JAX Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 +// * output: per-tensor qi8 +// +// Conditions for the `tfl.batch_matmul` conversion: +// * size(batching_dimensions) <= 3 (TFLite support restriction) +// * size(contracting_dimensions) = 1 +// * Input tensors are per-tensor uniform quantized (i8->f32) +// tensors (full integer) with shape [..., r_x, c_x] or [..., c_x, r_x]. +// * The filter tensor is a per-tensor uniform quantized (i8->f32) tensor +// (constant or activation) with shape [..., r_y, c_y] or [..., c_y, r_y]. +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// +// Conditions for `tfl.fully_connected` conversion: +// * Input tensors are per-tensor uniform quantized (i8->f32) // tensors. 
-// * The filter tensor is constant a per-channel uniform quantized (i8->f32) -// tensor. -// * Convolution is a 2D convolution op and both the input's and filter's -// shape is 4 dimensional. -// * The filter tensor's format is `[0, 1, i, o]`. -// * Not a depthwise convolution. -// * Does not consider bias add fusion. -// TODO: b/294771704 - Support bias quantization. -class RewriteUpstreamQuantizedConvolutionOp - : public OpRewritePattern { +// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) +// tensor. The quantization dimension should be 1 (the non-contracting +// dimension). +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// * The input tensor's rank is either 2 or 3. The last dimension of the input +// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. +// * The filter tensor's rank is 2. The contracting dimension should be the +// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. +// TODO: b/309896242 - Add support for fused op case. Add support for +// per-channel quantization. +// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands +// is not specified in the StableHLO dialect. Update the spec to allow this. 
+class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + public: - using OpRewritePattern::OpRewritePattern; + LogicalResult match(stablehlo::DotGeneralOp op) const override { + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + const bool is_batch_matmul = + !dot_dimension_nums.getLhsBatchingDimensions().empty(); + const bool has_i32_output = IsI32F32UniformQuantizedType( + op.getResult().getType().cast().getElementType()); - static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); - if (input_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected input rank of 4. Got: " - << input_type.getRank() << ".\n"); + if (failed(MatchInputDotGeneralCommonPattern(op.getLhs()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input for quantized dot_general.\n"); + return failure(); + } + if (failed(MatchFilterCommonPattern(op.getRhs()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match filter for quantized dot_general.\n"); + return failure(); + } + if (failed(MatchOutput(op.getResult(), has_i32_output))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general.\n"); return failure(); } - if (const auto input_element_type = input_type.getElementType(); - !IsI8F32UniformQuantizedType(input_element_type)) { + if (is_batch_matmul) { + return MatchDotGeneralToTflBatchMatmulOp(op, dot_dimension_nums, + has_i32_output); + } + return MatchDotGeneralToTflFullyConnectedOp(op, dot_dimension_nums, + has_i32_output); + } + + void rewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + const bool has_i32_output = IsI32F32UniformQuantizedType( + op.getResult().getType().cast().getElementType()); + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + const bool is_batch_matmul = 
+ !dot_dimension_nums.getLhsBatchingDimensions().empty(); + + if (is_batch_matmul) { + RewriteDotGeneralToTflBatchMatmulOp(op, rewriter, dot_dimension_nums, + has_i32_output); + } else { + RewriteDotGeneralToTflFullyConnectedOp(op, rewriter, dot_dimension_nums, + has_i32_output); + } + } + + private: + static LogicalResult MatchDotGeneralToTflBatchMatmulOp( + stablehlo::DotGeneralOp op, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const int num_lhs_batching_dims = + dot_dimension_nums.getLhsBatchingDimensions().size(); + const int num_lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions().size(); + if (num_lhs_batching_dims > 3) { + LLVM_DEBUG(llvm::dbgs() << "Failed to match batching dimension for " + "quantized dot_general.\n"); + return failure(); + } + // Checking one side is enough since + // (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions). + if (num_lhs_contracting_dims != 1) { + // Check one side is enough since + // (C2) size(lhs_contracting_dimensions) = + // size(rhs_contracting_dimensions). + LLVM_DEBUG(llvm::dbgs() << "Failed to match contract dimension for " + "quantized dot_general.\n"); + return failure(); + } + const auto input_type = op.getLhs().getType().cast(); + const int input_rank = input_type.getRank(); + const auto input_contracting_dim = + dot_dimension_nums.getLhsContractingDimensions()[0]; + if ((input_contracting_dim != input_rank - 1) && + (input_contracting_dim != input_rank - 2)) { LLVM_DEBUG(llvm::dbgs() - << "Expected an i8->f32 uniform quantized type. 
Got: " - << input_element_type << ".\n"); + << "Failed to match input contracting dimensions.\n"); return failure(); } + const auto filter_type = op.getRhs().getType().cast(); + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedType(filter_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a per-tensor uniform " + "quantized (i8->f32) weight for dot_general. Got: " + << filter_type << "\n"); + return failure(); + } + const int rhs_rank = filter_type.cast().getRank(); + const auto rhs_contracting_dim = + dot_dimension_nums.getRhsContractingDimensions()[0]; + if ((rhs_contracting_dim != rhs_rank - 1) && + (rhs_contracting_dim != rhs_rank - 2)) { + LLVM_DEBUG(llvm::dbgs() + << "Not supported rhs contracting dim for dot_general.\n"); + return failure(); + } return success(); } - static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); - if (filter_type.getRank() != 4) { - LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " - "Expected filter rank of 4. Got: " - << filter_type.getRank() << ".\n"); + static LogicalResult MatchDotGeneralToTflFullyConnectedOp( + stablehlo::DotGeneralOp op, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const int num_lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions().size(); + const int num_rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions().size(); + if (num_lhs_contracting_dims != 1 || num_rhs_contracting_dims != 1) { + LLVM_DEBUG(llvm::dbgs() + << "Expected number of contracting dimensions to be 1. Got: " + << num_rhs_contracting_dims << ".\n"); return failure(); } - const Type filter_element_type = filter_type.getElementType(); - if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { - LLVM_DEBUG( - llvm::dbgs() - << "Expected a per-channel uniform quantized (i8->f32) type. 
Got: " - << filter_element_type << "\n"); + const auto input_type = op.getLhs().getType().cast(); + if (!(input_type.getRank() == 2 || input_type.getRank() == 3)) { + LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " + << input_type << ".\n"); return failure(); } - if (filter_element_type.cast() - .getQuantizedDimension() != 3) { - LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. Got: " - << filter_element_type << "\n"); + const auto filter_type = op.getRhs().getType().cast(); + if (filter_type.getRank() != 2) { + LLVM_DEBUG(llvm::dbgs() + << "Filter tensor expected to have a tensor rank of 2. Got: " + << filter_type << ".\n"); return failure(); } + if (has_i32_output) { + if (!IsI8F32UniformQuantizedType(filter_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a per-channel uniform quantized " + "(i8->f32) type. Got: " + << filter_type.getElementType() << "\n"); + return failure(); + } + } else { + if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a per-channel uniform quantized " + "(i8->f32) type. Got: " + << filter_type.getElementType() << "\n"); + return failure(); + } + } - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + // If the op has a fusible bias, make sure the bias is a constant. 
+ if (auto add_op = FindUserOfType(op); + add_op != nullptr && + !isa(add_op->getOperand(1).getDefiningOp())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a `stablehlo.constant` as the " + << "rhs of `stablehlo.add`.\n"); + } + + return success(); + } + + static LogicalResult MatchInputDotGeneralCommonPattern(const Value input) { + const auto input_type = input.getType().cast(); + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); + return failure(); + } + + if (!input_type.hasRank()) { + LLVM_DEBUG(llvm::dbgs() << "Expected input_type to have rank.\n"); return failure(); } + return success(); + } + static LogicalResult MatchFilterCommonPattern(const Value filter) { + auto filter_type = filter.getType().cast(); + if (!filter_type.hasRank()) { + LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " + << filter.getType() << "\n"); + return failure(); + } return success(); } - static LogicalResult MatchOutput(Value output) { + static LogicalResult MatchOutput(const Value output, + const bool has_i32_output) { const Type output_element_type = output.getType().cast().getElementType(); + if (has_i32_output) { + if (!IsI32F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-tensor uniform quantized (i32->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + return success(); + } if (!IsI8F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << output_element_type << ".\n"); + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-tensor uniform quantized (i8->f32) type. 
Got: " + << output_element_type << ".\n"); return failure(); } - return success(); } + static void RewriteDotGeneralToTflBatchMatmulOp( + stablehlo::DotGeneralOp op, PatternRewriter& rewriter, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const auto rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions(); + const auto lhs_contracting_dims = + dot_dimension_nums.getLhsContractingDimensions(); + + const Value rhs_value = op.getRhs(); + const Value lhs_value = op.getLhs(); + + Operation* rhs_op = rhs_value.getDefiningOp(); + auto filter_constant_op = dyn_cast_or_null(rhs_op); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. + const BoolAttr asymmetric_quantize_inputs = nullptr; + + const int lhs_rank = lhs_value.getType().cast().getRank(); + const BoolAttr adj_x = + (lhs_contracting_dims[0] == lhs_rank - 2 ? rewriter.getBoolAttr(true) + : rewriter.getBoolAttr(false)); + const int rhs_rank = rhs_value.getType().cast().getRank(); + const BoolAttr adj_y = + (rhs_contracting_dims[0] == rhs_rank - 1 ? rewriter.getBoolAttr(true) + : rewriter.getBoolAttr(false)); + + // Create BMM assuming rhs is activation. + auto tfl_batchmatmul_op = rewriter.create( + op.getLoc(), /*output=*/op.getResult().getType(), + /*input=*/lhs_value, + /*filter=*/rhs_value, adj_x, adj_y, asymmetric_quantize_inputs); + + // Update BMM if rhs is a constant. 
+ if (filter_constant_op != nullptr) { + const auto rhs_uniform_quantized_type = + rhs_value.getType().cast(); + const auto rhs_constant_value_attr = + cast(filter_constant_op.getValue()); + auto rhs_constant_op = rewriter.create( + rhs_op->getLoc(), + /*output=*/TypeAttr::get(rhs_uniform_quantized_type), + rhs_constant_value_attr); + tfl_batchmatmul_op = rewriter.create( + op.getLoc(), /*output=*/op.getResult().getType(), + /*input=*/lhs_value, /*filter=*/rhs_constant_op.getResult(), adj_x, + adj_y, asymmetric_quantize_inputs); + } + + rewriter.replaceAllUsesWith(op.getResult(), tfl_batchmatmul_op.getResult()); + } + static void RewriteDotGeneralToTflFullyConnectedOp( + stablehlo::DotGeneralOp op, PatternRewriter& rewriter, + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, + const bool has_i32_output) { + const Value rhs_value = op.getRhs(); + const Value lhs_value = op.getLhs(); + + Operation* rhs_op = rhs_value.getDefiningOp(); + const auto filter_constant_op = + dyn_cast_or_null(rhs_op); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. + const BoolAttr asymmetric_quantize_inputs = nullptr; + + // Checks for `tfl.fully_connected` condition. + + // StableHLO Quantizer does not yet support per-channel quantization of + // dot_general. + const bool is_per_channel = !has_i32_output; + // Create the new filter constant - transpose filter value + // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for + // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas + // `tfl.fully_connected` accepts an OI format. 
+ TFL::QConstOp new_filter_constant_op = + CreateTflConstOpForFilter(filter_constant_op, rewriter, is_per_channel); + + const double input_scale = lhs_value.getType() + .cast() + .getElementType() + .cast() + .getScale(); + TFL::QConstOp bias_tfl_op; + bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; + // Get the desired output type and extract any existing fusible bias + // as `TFL::QConstOp` so that it can be fused with TFL::FullyConnectedOp`. + TensorType output_type = GetOutputTypeAndOptionallyUpdateBias( + op, rewriter, &bias_tfl_op, has_i32_output, fuse_bias_constant); + + // If there is no explicit bias, create a dummy value filled with zeroes. + if (!fuse_bias_constant) { + bias_tfl_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + is_per_channel, *op.getContext()); + } + rewriter.replaceOpWithNewOp( + op, /*output=*/output_type, + /*input=*/lhs_value, + /*filter=*/new_filter_constant_op.getResult(), + /*bias=*/bias_tfl_op.getResult(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + asymmetric_quantize_inputs); + } + + static TensorType GetOutputTypeAndOptionallyUpdateBias( + Operation* op, PatternRewriter& rewriter, TFL::QConstOp* bias_tfl_op, + const bool has_i32_output, const bool fuse_bias_constant) { + TensorType output_type; + if (has_i32_output) { + Operation* uniform_quantize_op; + if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + uniform_quantize_op = FindUserOfType(add_op); + auto filter_quantized_type = op->getOperand(1) + .getType() + .cast() + .getElementType() + .cast(); + double bias_scale = GetBiasScale( + /*input_scale=*/op->getOperand(0) + .getType() + .cast() + .getElementType() + .cast() + .getScale(), + /*filter_scale=*/filter_quantized_type.getScale()); + ArrayRef output_shape = + op->getResult(0).getType().cast().getShape(); + 
const SmallVector bias_shape = { + output_shape[output_shape.size() - 1]}; + auto bias_quantized_type = CreateI32F32UniformQuantizedType( + op->getLoc(), *op->getContext(), std::move(bias_scale), + op->getResult(0) + .getType() + .cast() + .getElementType() + .cast() + .getZeroPoint()); + Operation* stablehlo_bias_op = add_op->getOperand(1).getDefiningOp(); + auto bias_type = RankedTensorType::getChecked(op->getLoc(), bias_shape, + bias_quantized_type); + auto bias_value = cast( + cast(stablehlo_bias_op).getValue()); + + *bias_tfl_op = rewriter.create( + op->getLoc(), + /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + } else { + uniform_quantize_op = FindUserOfType(op); + } + + auto result_quantized_type = uniform_quantize_op->getResult(0) + .getType() + .cast() + .getElementType() + .cast(); + auto new_result_quantized_type = CreateI8F32UniformQuantizedType( + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), + result_quantized_type.getZeroPoint()); + output_type = op->getResult(0).getType().cast().clone( + new_result_quantized_type); + // Omit any bias and requantize ops as `tfl.fully_connected` outputs a + // fused `qi8` type. + FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + } else { + output_type = op->getResult(0).getType().cast(); + } + return output_type; + } +}; + +// Rewrites `stablehlo.convolution` into fused `tfl.conv_2d`. +// If available, fuse bias and activation adjacent to `stablehlo.convolution`. 
+// This RewritePattern rewrites both the following into `tfl.conv_2d` op: +// +// StableHLO Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 (`quantization_dimension` = 3) +// * output: per-channel qi32 (`quantization_dimension` = 3) +// JAX Quantizer output: +// * input: per-tensor qi8 +// * filter: per-channel qi8 (`quantization_dimension` = 3) +// * output: per-tensor qi8 +// +// Conditions for the conversion: +// * Input tensors are per-tensor uniform quantized (i8->f32) +// tensors. +// * The filter tensor is constant a per-channel uniform quantized (i8->f32) +// tensor. +// * Output tensors are per-tensor uniform quantized (i8->f32) or +// per-channel uniform quantized (i32->f32) tensors. +// * Convolution is a 2D convolution op and both the input's and filter's +// shape is 4 dimensional. +// * The filter tensor's format is `[0, 1, i, o]`. +// * Not a depthwise convolution. +class RewriteQuantizedConvolutionOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; LogicalResult match(stablehlo::ConvolutionOp op) const override { + const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( + op.getResult().getType().cast().getElementType()); + const bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; stablehlo::ConvDimensionNumbersAttr dimension_numbers = op.getDimensionNumbers(); @@ -446,14 +831,39 @@ class RewriteUpstreamQuantizedConvolutionOp return failure(); } + // TODO: b/309896242 - Lift the assumptions on adjacent ops below + // as we cover more dynamic fused pattern legalization. 
+ if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + if (add_op == nullptr) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find AddOp for bias fusion.\n"); + return failure(); + } + Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); + if (!isa(broadcast_in_dim_op)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find broadcasted bias.\n"); + return failure(); + } + Operation* bias_const_op = + broadcast_in_dim_op->getOperand(0).getDefiningOp(); + if (!isa(bias_const_op)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find bias constant.\n"); + return failure(); + } + } + return success(); } void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const override { + const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( + op.getResult().getType().cast().getElementType()); + stablehlo::ConvDimensionNumbersAttr dimension_numbers = + op.getDimensionNumbers(); + Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); - auto filter_uniform_quantized_type = filter_value.getType() .cast() @@ -467,15 +877,11 @@ class RewriteUpstreamQuantizedConvolutionOp // (https://github.com/tensorflow/tensorflow/blob/5430e5e238f868ce977df96ba89c9c1d31fbe8fa/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L933). // The quantized dimension should correspond to the output feature // dimension. 
- auto new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( - filter_op->getLoc(), /*flags=*/true, - /*storageType=*/filter_uniform_quantized_type.getStorageType(), - filter_uniform_quantized_type.getExpressedType(), + auto new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( + filter_op->getLoc(), *op.getContext(), filter_uniform_quantized_type.getScales(), filter_uniform_quantized_type.getZeroPoints(), - /*quantizedDimension=*/0, - filter_uniform_quantized_type.getStorageTypeMin(), - filter_uniform_quantized_type.getStorageTypeMax()); + /*quantization_dimension=*/0, /*narrow_range=*/true); auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); @@ -495,48 +901,33 @@ class RewriteUpstreamQuantizedConvolutionOp filter_op->getLoc(), /*output=*/TypeAttr::get(new_filter_result_type), new_filter_value_attr); - SmallVector bias_scales = - GetBiasScales(/*input_scale=*/op.getOperand(0) - .getType() - .cast() - .getElementType() - .cast() - .getScale(), - /*filter_scales=*/new_filter_quantized_type.getScales()); + Operation* uniform_quantize_op; + const bool fuse_bias_constant = + FindUserOfType(op) && has_i32_output; + if (has_i32_output) { + if (fuse_bias_constant) { + Operation* add_op = FindUserOfType(op); + uniform_quantize_op = FindUserOfType(add_op); + } else { + uniform_quantize_op = FindUserOfType(op); + } + } - // Create a bias filled with zeros. Mimics the behavior of no bias add. const int64_t num_output_features = new_filter_result_type.getShape()[0]; const SmallVector bias_shape = {num_output_features}; - auto bias_quantized_type = UniformQuantizedPerAxisType::getChecked( - op.getLoc(), /*flags=*/true, - /*storageType=*/rewriter.getI32Type(), // i32 for bias - /*expressedType=*/rewriter.getF32Type(), - /*scales=*/std::move(bias_scales), - /*zeroPoints=*/new_filter_quantized_type.getZeroPoints(), // Zeros. 
- /*quantizedDimension=*/0, - /*storageTypeMin=*/std::numeric_limits::min(), - /*storageTypeMax=*/std::numeric_limits::max()); - auto bias_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, - bias_quantized_type); - - // Create a bias constant. It should have values of 0. - auto bias_value_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, - rewriter.getI32Type()); - auto bias_value = DenseIntElementsAttr::get( - bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); - auto bias = rewriter.create( - op.getLoc(), /*output=*/TypeAttr::get(bias_type), - /*value=*/bias_value); + + TFL::QConstOp bias = GetBiasOp(op, rewriter, new_filter_result_type, + new_filter_quantized_type, bias_shape, + has_i32_output, fuse_bias_constant); // Determine the attributes for the TFL::Conv2DOp. - // TODO: b/294808863 - Use `padding = "SAME"` if the padding attribute - // matches the semantics. + Value input_value = op.getOperand(0); if (const DenseIntElementsAttr padding_attr = op.getPaddingAttr(); - !IsPaddingValid(padding_attr)) { + !HasProperPadding(op, dimension_numbers, padding_attr)) { // Add an extra tfl.pad_op if there are explicit padding values. This - // extra pad op will allow us to always set the `padding` attribute of the - // newly created tfl.conv_2d op as "VALID". + // extra pad op will allow us to always set the `padding` attribute of + // the newly created tfl.conv_2d op as "VALID". TFL::PadOp pad_op = CreateTflPadOp(op.getLoc(), padding_attr, input_value, rewriter); input_value = pad_op.getResult(); @@ -545,40 +936,130 @@ class RewriteUpstreamQuantizedConvolutionOp const auto [stride_h, stride_w] = GetStrides(op); const auto [dilation_h_factor, dilation_w_factor] = GetDilationFactors(op); - auto tfl_conv2d_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), /*input=*/input_value, + Type output_type; + if (has_i32_output) { + // StableHLO Quantizer outputs an i32 type. 
Rewrite to i8 type result + // to meet TFLite op requirement. + auto result_quantized_type = uniform_quantize_op->getResult(0) + .getType() + .cast() + .getElementType() + .cast(); + auto new_result_quantized_type = CreateI8F32UniformQuantizedType( + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), + result_quantized_type.getZeroPoint()); + output_type = op.getResult().getType().cast().clone( + new_result_quantized_type); + // Omit any bias and requantize ops as `tfl.fully_connected` outputs a + // fused `qi8` type. + FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + } else { + output_type = op.getResult().getType(); + } + rewriter.replaceOpWithNewOp( + // op result should be recasted to desired quantized type. + op, output_type, + /*input=*/input_value, /*filter=*/new_filter_constant_op, /*bias=*/bias.getResult(), /*dilation_h_factor=*/rewriter.getI32IntegerAttr(dilation_h_factor), /*dilation_w_factor=*/rewriter.getI32IntegerAttr(dilation_w_factor), /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*padding=*/rewriter.getStringAttr("VALID"), + /*padding=*/ + rewriter.getStringAttr(UseSamePadding(op, dimension_numbers) + ? kPaddingSame + : kPaddingValid), /*stride_h=*/rewriter.getI32IntegerAttr(stride_h), /*stride_w=*/rewriter.getI32IntegerAttr(stride_w)); - - rewriter.replaceAllUsesWith(op.getResult(), tfl_conv2d_op.getResult()); - rewriter.eraseOp(op); } private: - // Create a `tfl.pad` op to apply explicit padding to the input tensor that - // correspond to the `padding` attribute from the `stablehlo.convolution` op. - TFL::PadOp CreateTflPadOp(Location loc, - const DenseIntElementsAttr& padding_attr, - Value input_value, - PatternRewriter& rewriter) const { - auto padding_values = padding_attr.getValues(); - // [[h_l, h_r], [w_l, w_r]]. 
- DCHECK_EQ(padding_attr.size(), 4); + static LogicalResult MatchInput(Value input) { + auto input_type = input.getType().cast(); + if (input_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected input rank of 4. Got: " + << input_type.getRank() << ".\n"); + return failure(); + } - // In StableHLO the padding attribute doesn't include the padding values for - // input and output feature dimensions (because they are 0 anyways). In - // TFLite, padding values for input and output feature dimensions should be - // explicitly set to 0s. Note that TFLite's input tensor is formatted as - // OHWI. The resulting pad values becomes: [[0, 0], [h_l, h_r], [w_l, w_r], + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); + return failure(); + } + + return success(); + } + + static LogicalResult MatchFilter(Value filter) { + auto filter_type = filter.getType().cast(); + if (filter_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected filter rank of 4. Got: " + << filter_type.getRank() << ".\n"); + return failure(); + } + + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-channel uniform quantized (i8->f32) type. Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (filter_element_type.cast() + .getQuantizedDimension() != 3) { + LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. 
Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (Operation* filter_op = filter.getDefiningOp(); + filter_op == nullptr || !isa(filter_op)) { + LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + return failure(); + } + return success(); + } + + static LogicalResult MatchOutput(Value output) { + const Type output_element_type = + output.getType().cast().getElementType(); + if (!IsI32F32UniformQuantizedPerAxisType(output_element_type) && + !IsI8F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-channel uniform quantized (i32->f32) type or " + << "per-tensor uniform quantized (i8->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + return success(); + } + // Create a `tfl.pad` op to apply explicit padding to the input tensor that + // correspond to the `padding` attribute from the `stablehlo.convolution` op. + TFL::PadOp CreateTflPadOp(Location loc, + const DenseIntElementsAttr& padding_attr, + Value input_value, + PatternRewriter& rewriter) const { + auto padding_values = padding_attr.getValues(); + // [[h_l, h_r], [w_l, w_r]]. + DCHECK_EQ(padding_attr.size(), 4); + + // In StableHLO the padding attribute doesn't include the padding values for + // input and output feature dimensions (because they are 0 anyways). In + // TFLite, padding values for input and output feature dimensions should be + // explicitly set to 0s. Note that TFLite's input tensor is formatted as + // OHWI. The resulting pad values becomes: [[0, 0], [h_l, h_r], [w_l, w_r], // [0, 0]] SmallVector tfl_pad_values = {0, 0}; // For output feature dim. for (const int64_t padding_value : padding_values) { - tfl_pad_values.push_back(static_cast(padding_value)); + tfl_pad_values.push_back(CastI64ToI32(padding_value).value()); } // For input feature dim. 
tfl_pad_values.push_back(0); @@ -634,7 +1115,7 @@ class RewriteUpstreamQuantizedConvolutionOp ArrayRef filter_shape = filter_value_attr.getShapedType().getShape(); SmallVector filter_constant_values; - for (const auto filter_val : filter_value_attr.getValues()) { + for (auto filter_val : filter_value_attr.getValues()) { filter_constant_values.push_back(filter_val); } @@ -659,8 +1140,8 @@ class RewriteUpstreamQuantizedConvolutionOp for (int k = 0; k < filter_shape[2]; ++k) { for (int l = 0; l < filter_shape[3]; ++l) { // [i][j][k][l] -> [l][i][j][k] - const int old_idx = get_array_idx(filter_shape, i, j, k, l); - const int new_idx = get_array_idx(new_filter_shape, l, i, j, k); + int old_idx = get_array_idx(filter_shape, i, j, k, l); + int new_idx = get_array_idx(new_filter_shape, l, i, j, k); new_filter_constant_values[new_idx] = filter_constant_values[old_idx]; @@ -679,23 +1160,44 @@ class RewriteUpstreamQuantizedConvolutionOp return new_filter_constant_value_attr; } - // Determines if the padding attribute corresponds to "VALID" + bool UseSamePadding( + Operation* op, + stablehlo::ConvDimensionNumbersAttr dimension_numbers) const { + // TODO: b/294808863 - Account for dynamic shapes. + const ArrayRef input_shape = + op->getOperand(0).getType().cast().getShape(); + const ArrayRef output_shape = + op->getResult(0).getType().cast().getShape(); + const ArrayRef input_spatial_dim_inds = + dimension_numbers.getInputSpatialDimensions(); + const ArrayRef output_spatial_dim_inds = + dimension_numbers.getOutputSpatialDimensions(); + return (input_shape[input_spatial_dim_inds[0]] == + output_shape[output_spatial_dim_inds[0]] && + input_shape[input_spatial_dim_inds[1]] == + output_shape[output_spatial_dim_inds[1]]); + } + + // Determines if the padding attribute corresponds to "VALID" or "SAME". + // If not, the input's shape should be adjusted with explicit `tfl.pad` op. // (https://www.tensorflow.org/api_docs/python/tf/nn). 
- bool IsPaddingValid(const DenseIntElementsAttr& padding_attr) const { + bool HasProperPadding(Operation* op, + stablehlo::ConvDimensionNumbersAttr dimension_numbers, + const DenseIntElementsAttr& padding_attr) const { // If padding_attr is empty, it defaults to splat 0s. - return !padding_attr || (padding_attr.isSplat() && - padding_attr.getSplatValue() == 0); + return UseSamePadding(op, dimension_numbers) || + (!padding_attr || (padding_attr.isSplat() && + padding_attr.getSplatValue() == 0)); } // Returns the stride amount for the height and width, respectively. std::pair GetStrides(stablehlo::ConvolutionOp op) const { - const Attribute window_strides_attr = op.getWindowStridesAttr(); + DenseI64ArrayAttr window_strides_attr = op.getWindowStridesAttr(); if (!window_strides_attr) { return {1, 1}; // Default values. } - const auto window_strides_attr_value = - hlo::getI64Array(window_strides_attr); + auto window_strides_attr_value = window_strides_attr.asArrayRef(); // It is guaranteed from the spec that it has two values: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. return {window_strides_attr_value[0], window_strides_attr_value[1]}; @@ -704,580 +1206,510 @@ class RewriteUpstreamQuantizedConvolutionOp // Returns the dilation amount for the height and width, respectively. std::pair GetDilationFactors( stablehlo::ConvolutionOp op) const { - const Attribute lhs_dilation_attr = op.getLhsDilationAttr(); + DenseI64ArrayAttr lhs_dilation_attr = op.getLhsDilationAttr(); if (!lhs_dilation_attr) { return {1, 1}; // Default values. } - const auto lhs_dilation_attr_value = hlo::getI64Array(lhs_dilation_attr); + auto lhs_dilation_attr_value = lhs_dilation_attr.asArrayRef(); // It is guaranteed from the spec that it has two values: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. 
return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]}; } + + TFL::QConstOp GetBiasOp( + stablehlo::ConvolutionOp op, PatternRewriter& rewriter, + const RankedTensorType new_filter_result_type, + const UniformQuantizedPerAxisType new_filter_quantized_type, + const SmallVector bias_shape, const bool has_i32_output, + const bool fuse_bias_constant) const { + const SmallVector bias_scales = GetBiasScales( + /*input_scale=*/op.getOperand(0) + .getType() + .cast() + .getElementType() + .cast() + .getScale(), + /*filter_scales=*/new_filter_quantized_type.getScales()); + + const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( + op.getLoc(), *op.getContext(), std::move(bias_scales), + new_filter_quantized_type.getZeroPoints(), + /*quantization_dimension=*/0); + const auto bias_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, + bias_quantized_type); + TFL::QConstOp bias; + if (fuse_bias_constant && has_i32_output) { + Operation* add_op = FindUserOfType(op); + // TODO: b/309896242 - Lift the assumptions on adjacent ops below + // as we cover more dynamic fused pattern legalization. + Operation* broadcast_in_dim_op = add_op->getOperand(1).getDefiningOp(); + Operation* bias_const_op = + broadcast_in_dim_op->getOperand(0).getDefiningOp(); + const ElementsAttr bias_constant_value = + cast(bias_const_op).getValue(); + bias = rewriter.create(op.getLoc(), + /*output=*/TypeAttr::get(bias_type), + /*value=*/bias_constant_value); + } else { + // Create a bias constant. It should have values of 0. + const auto bias_value_type = RankedTensorType::getChecked( + op.getLoc(), bias_shape, rewriter.getI32Type()); + // Create a bias filled with zeros. Mimics the behavior of no bias add. 
+ const auto bias_value = DenseIntElementsAttr::get( + bias_value_type, + APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + bias = rewriter.create(op.getLoc(), + /*output=*/TypeAttr::get(bias_type), + /*value=*/bias_value); + } + return bias; + } }; -// Rewrites full-integer quantized `stablehlo.dot_general` ->`tfl.batch_matmul` -// when it accepts uniform quantized tensors. -// -// Since transpose and reshape of quantized tensors are not natively supported -// at the moment, the conversion condition is relatively strict, following -// (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul-v3) -// -// Conditions for the conversion : -// * size(batching_dimensions) <= 3 (TFLite support restriction) -// * size(contracting_dimensions) = 1 -// * Input (lhs) and output tensors are per-tensor uniform quantized (i8->f32) -// tensors (full integer) with shape [..., r_x, c_x] or [..., c_x, r_x]. -// * The rhs tensor is a per-tensor uniform quantized (i8->f32) tensor -// (constant or activation) with shape [..., r_y, c_y] or [..., c_y, r_y]. -// -// TODO: b/293650675 - Relax the conversion condition to support dot_general in -// general. -class RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp - : public OpRewritePattern { +// Rewrites quantized stablehlo.transpose to tfl.transpose. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedTransposeOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - static LogicalResult MatchLhs( - Value lhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - auto lhs_type = lhs.getType().cast(); - if (!IsI8F32UniformQuantizedType(lhs_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) input for dot_general. 
Got: " - << lhs_type << "\n"); - return failure(); - } - if (!lhs_type.hasRank()) { - LLVM_DEBUG(llvm::dbgs() << "Expected lhs of dot_general has rank. Got: " - << lhs_type << "\n"); - return failure(); - } - const int lhs_rank = lhs_type.getRank(); - auto lhs_contracting_dim = - dimension_numbers.getLhsContractingDimensions()[0]; - if ((lhs_contracting_dim != lhs_rank - 1) && - (lhs_contracting_dim != lhs_rank - 2)) { - LLVM_DEBUG(llvm::dbgs() - << "Not supported lhs contracting dim for dot_general.\n"); - return failure(); - } - return success(); + LogicalResult match(stablehlo::TransposeOp op) const override { + return success(IsOpFullyQuantized(op)); } - static LogicalResult MatchRhs( - Value rhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - if (!rhs.getType().cast().hasRank()) { - LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " - << rhs.getType() << "\n"); - return failure(); - } - const int rhs_rank = rhs.getType().cast().getRank(); - auto rhs_contracting_dim = - dimension_numbers.getRhsContractingDimensions()[0]; - if ((rhs_contracting_dim != rhs_rank - 1) && - (rhs_contracting_dim != rhs_rank - 2)) { - LLVM_DEBUG(llvm::dbgs() - << "Not supported rhs contracting dim for dot_general.\n"); - return failure(); - } - - auto rhs_type = rhs.getType().cast(); - if (!IsI8F32UniformQuantizedType(rhs_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) weight for dot_general. Got: " - << rhs_type << "\n"); - return failure(); - } - return success(); + void rewrite(stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + auto operand_type = op.getOperand().getType().cast(); + const int64_t rank = operand_type.getRank(); + ArrayRef shape(rank); + TensorType permutation_type = + operand_type.cloneWith(shape, rewriter.getI32Type()); + // Cast permutation attribute from i64 to i32 as they are required to be i32 + // in TFLite. 
+ SmallVector permutation_i32 = + CastI64ArrayToI32(op.getPermutation()).value(); + auto permutation_attr = + DenseIntElementsAttr::get(permutation_type, permutation_i32); + auto permutation = + rewriter.create(op.getLoc(), permutation_attr); + rewriter.replaceOpWithNewOp(op, op.getOperand(), + permutation); } +}; - static LogicalResult MatchOutput( - Value output, stablehlo::DotDimensionNumbersAttr dimension_numbers) { - auto output_type = output.getType().cast(); - if (!IsI8F32UniformQuantizedType(output_type.getElementType())) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a per-tensor uniform " - "quantized (i8->f32) output for dot_general. Got: " - << output_type << "\n"); - return failure(); - } - return success(); +// Rewrites quantized stablehlo.reshape to tfl.reshape. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedReshapeOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::ReshapeOp op) const override { + return success(IsOpFullyQuantized(op)); } - LogicalResult match(stablehlo::DotGeneralOp op) const override { - stablehlo::DotDimensionNumbersAttr dimension_numbers = - op.getDotDimensionNumbers(); + void rewrite(stablehlo::ReshapeOp op, + PatternRewriter& rewriter) const override { + auto result_type = op->getResult(0).getType().cast(); + // Cast result shapes from i64 to i32 as they are required to be i32 in + // TFLite. 
+ SmallVector shape_i32 = + CastI64ArrayToI32(result_type.getShape()).value(); + + const int64_t shape_length = shape_i32.size(); + ArrayRef shape(shape_length); + TensorType shape_type = result_type.cloneWith(shape, rewriter.getI32Type()); + auto shape_attr = DenseIntElementsAttr::get(shape_type, shape_i32); + auto new_shape = + rewriter.create(op.getLoc(), shape_attr); + rewriter.replaceOpWithNewOp(op, op.getOperand(), new_shape); + } +}; - // Check one side is enough since - // (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - if (dimension_numbers.getLhsBatchingDimensions().size() > 3) { - LLVM_DEBUG( - llvm::dbgs() - << "Failed to match batch dimention for quantized dot_general.\n"); - return failure(); - } - // Check one side is enough since - // (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - if (dimension_numbers.getLhsContractingDimensions().size() != 1) { - LLVM_DEBUG( - llvm::dbgs() - << "Failed to match contract dimention for quantized dot_general.\n"); - return failure(); - } +// Rewrites quantized stablehlo.select to tfl.select_v2. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
+class RewriteQuantizedSelectOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - if (failed(MatchLhs(op.getLhs(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general.\n"); + LogicalResult match(stablehlo::SelectOp op) const override { + if (!IsQuantizedTensorType(op.getOperand(1).getType())) { return failure(); } - if (failed(MatchRhs(op.getRhs(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match weight for quantized dot_general.\n"); + if (!IsQuantizedTensorType(op.getOperand(2).getType())) { return failure(); } - - if (failed(MatchOutput(op.getResult(), dimension_numbers))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general.\n"); + if (!IsQuantizedTensorType(op.getResult().getType())) { return failure(); } - return success(); } - void rewrite(stablehlo::DotGeneralOp op, + void rewrite(stablehlo::SelectOp op, PatternRewriter& rewriter) const override { - Value rhs_value = op.getRhs(); - Operation* rhs_op = rhs_value.getDefiningOp(); - - stablehlo::DotDimensionNumbersAttr dimension_numbers = - op.getDotDimensionNumbers(); - Value input_value = op.getLhs(); - const int lhs_rank = input_value.getType().cast().getRank(); - auto lhs_contracting_dim = - dimension_numbers.getLhsContractingDimensions()[0]; - BoolAttr adj_x = - (lhs_contracting_dim == lhs_rank - 2 ? rewriter.getBoolAttr(true) - : rewriter.getBoolAttr(false)); - auto rhs_contracting_dim = - dimension_numbers.getRhsContractingDimensions()[0]; - const int rhs_rank = rhs_value.getType().cast().getRank(); - BoolAttr adj_y = - (rhs_contracting_dim == rhs_rank - 1 ? rewriter.getBoolAttr(true) - : rewriter.getBoolAttr(false)); - - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - BoolAttr asymmetric_quantize_inputs = nullptr; - - // Create BMM assuming rhs is activation. 
- auto tfl_batchmatmul_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), /*input=*/input_value, - /*filter=*/rhs_value, adj_x, adj_y, asymmetric_quantize_inputs); - - // Update BMM if rhs is a constant. - auto const_rhs = dyn_cast_or_null(rhs_op); - if (const_rhs) { - auto rhs_uniform_quantized_type = rhs_value.getType().cast(); - auto rhs_constant_value_attr = - cast(const_rhs.getValue()); - auto rhs_constant_op = rewriter.create( - rhs_op->getLoc(), - /*output=*/TypeAttr::get(rhs_uniform_quantized_type), - rhs_constant_value_attr); - tfl_batchmatmul_op = rewriter.create( - op.getLoc(), /*output=*/op.getResult().getType(), - /*input=*/input_value, /*filter=*/rhs_constant_op.getResult(), adj_x, - adj_y, asymmetric_quantize_inputs); - } - - rewriter.replaceAllUsesWith(op.getResult(), tfl_batchmatmul_op.getResult()); + Value pred = op.getOperand(0); + Value on_true = op.getOperand(1); + Value on_false = op.getOperand(2); + rewriter.replaceOpWithNewOp(op, pred, on_true, on_false); } }; -// Rewrites `stablehlo.dot_general` -> `tfl.fully_connected` when it accepts -// uniform quantized tensors with per-axis quantized filter tensor (rhs). -// -// Conditions for the conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) -// tensors. -// * The filter tensor is constant a per-channel uniform quantized (i8->f32) -// tensor. The quantization dimension should be 1 (the non-contracting -// dimension). -// * The input tensor's rank is either 2 or 3. The last dimension of the input -// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. -// * The filter tensor's rank is 2. The contracting dimension should be the -// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. -// * Does not consider activation fusion. -// * Does not consider bias add fusion. -// -// TODO: b/294983811 - Merge this pattern into -// `RewriteFullIntegerQuantizedDotGeneralOp`. 
-// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands -// is not specified in the StableHLO dialect. Update the spec to allow this. -class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - +// Rewrites quantized stablehlo.concatenate to tfl.concatenation. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedConcatenateOp + : public OpRewritePattern { public: - LogicalResult match(stablehlo::DotGeneralOp op) const override { - const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = - op.getDotDimensionNumbers(); - if (const int num_rhs_contracting_dims = - dot_dimension_nums.getRhsContractingDimensions().size(); - num_rhs_contracting_dims != 1) { - LLVM_DEBUG(llvm::dbgs() - << "Expected number of contracting dimensions to be 1. Got: " - << num_rhs_contracting_dims << ".\n"); - return failure(); - } - - if (failed(MatchInput(op.getOperand(0)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general op.\n"); - return failure(); - } - - if (failed(MatchFilter(op.getOperand(1)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match filter for quantized dot_general op.\n"); - return failure(); - } + using OpRewritePattern::OpRewritePattern; - if (failed(MatchOutput(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general op.\n"); - return failure(); - } - - return success(); + LogicalResult match(stablehlo::ConcatenateOp op) const override { + return success(IsOpFullyQuantized(op)); } - void rewrite(stablehlo::DotGeneralOp op, + void rewrite(stablehlo::ConcatenateOp op, PatternRewriter& rewriter) const override { - // Create the new filter constant - transpose filter value - // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for - // `stablehlo.dot_general` (i.e. 
contracting dimension == 1) whereas - // `tfl.fully_connected` accepts an OI format. - auto filter_constant_op = - cast(op.getOperand(1).getDefiningOp()); - - TFL::QConstOp new_filter_constant_op = - CreateTflConstOpForFilter(filter_constant_op, rewriter, - /*is_per_axis=*/true); - const Value input_value = op.getOperand(0); - const double input_scale = input_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter, - /*is_per_axis=*/true); + Type output_type = op.getResult().getType(); + uint32_t axis = CastI64ToI32(op.getDimension()).value(); + rewriter.replaceOpWithNewOp( + op, output_type, op.getOperands(), axis, + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + } +}; - const Value result_value = op.getResult(); - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - const BoolAttr asymmetric_quantize_inputs = nullptr; - auto tfl_fully_connected_op = rewriter.create( - op.getLoc(), /*output=*/result_value.getType(), - /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), - /*bias=*/bias_constant_op.getResult(), - /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*weights_format=*/rewriter.getStringAttr("DEFAULT"), - /*keep_num_dims=*/rewriter.getBoolAttr(false), - asymmetric_quantize_inputs); +// Rewrites quantized stablehlo.pad to tfl.padv2. +// tfl.dilate is introduced in between when interior padding exists. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
+class RewriteQuantizedPadOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - rewriter.replaceAllUsesWith(result_value, - tfl_fully_connected_op.getResult(0)); - rewriter.eraseOp(op); + LogicalResult match(stablehlo::PadOp op) const override { + return success(IsOpFullyQuantized(op)); } - private: - static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); - if (!input_type.hasRank() || - !(input_type.getRank() == 2 || input_type.getRank() == 3)) { - LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " - << input_type << ".\n"); - return failure(); + void rewrite(stablehlo::PadOp op, PatternRewriter& rewriter) const override { + Value input = op.getOperand(); + // If any of the interior padding is non-zero, operand should be dilated + // first, and then padded. + if (llvm::any_of(op.getInteriorPadding(), + [](int64_t pad) { return pad != 0; })) { + input = InsertDilateOp(op, rewriter); } - if (const auto input_element_type = input_type.getElementType(); - !IsI8F32UniformQuantizedType(input_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected an i8->f32 uniform quantized type. Got: " - << input_element_type << ".\n"); - return failure(); + TensorType operand_type = input.getType().cast(); + const int64_t rank = operand_type.getRank(); + // Shape of padding should be [rank, 2]. 
+ SmallVector shape{rank, 2}; + TensorType padding_type = + operand_type.cloneWith(shape, rewriter.getI32Type()); + + ArrayRef padding_low = op.getEdgePaddingLow(); + ArrayRef padding_high = op.getEdgePaddingHigh(); + SmallVector padding_value; + for (int i = 0; i < rank; ++i) { + padding_value.push_back(CastI64ToI32(padding_low[i]).value()); + padding_value.push_back(CastI64ToI32(padding_high[i]).value()); } - return success(); + TensorType output_type = op.getResult().getType().cast(); + Value constant_values = op.getPaddingValue(); + auto padding_attr = DenseIntElementsAttr::get(padding_type, padding_value); + auto padding = + rewriter.create(op.getLoc(), padding_attr); + rewriter.replaceOpWithNewOp(op, output_type, input, padding, + constant_values); } - static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); - if (!filter_type.hasRank() || filter_type.getRank() != 2) { - LLVM_DEBUG(llvm::dbgs() - << "Filter tensor expected to have a tensor rank of 2. Got: " - << filter_type << ".\n"); - return failure(); - } - - const Type filter_element_type = filter_type.getElementType(); - if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { - LLVM_DEBUG( - llvm::dbgs() - << "Expected a per-channel uniform quantized (i8->f32) type. 
Got: " - << filter_element_type << "\n"); - return failure(); + Value InsertDilateOp(stablehlo::PadOp op, PatternRewriter& rewriter) const { + Value input = op.getOperand(); + TensorType operand_type = input.getType().cast(); + const int64_t rank = operand_type.getRank(); + + ArrayRef dilate_shape(rank); + TensorType dilate_type = + operand_type.cloneWith(dilate_shape, rewriter.getI32Type()); + ArrayRef interior_padding_i64 = op.getInteriorPadding(); + SmallVector interior_padding_i32 = + CastI64ArrayToI32(interior_padding_i64).value(); + auto dilate_attr = + DenseIntElementsAttr::get(dilate_type, interior_padding_i32); + auto dilate = rewriter.create(op.getLoc(), dilate_attr); + + // Shape after dilation. + SmallVector dilated_shape(rank); + ArrayRef operand_shape = operand_type.getShape(); + for (int i = 0; i < rank; ++i) { + dilated_shape[i] = + operand_shape[i] + interior_padding_i64[i] * (operand_shape[i] - 1); } + TensorType output_type = op.getResult().getType().cast(); + Type dilated_output_type = output_type.clone(dilated_shape); + Value constant_values = op.getPaddingValue(); - if (filter_element_type.cast() - .getQuantizedDimension() != 1) { - LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 1. Got: " - << filter_element_type << "\n"); - return failure(); - } + return rewriter.create(dilate.getLoc(), dilated_output_type, + input, dilate, constant_values); + } +}; - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); - return failure(); - } +// Rewrites quantized stablehlo.slice to tfl.slice or tfl.strided_slice. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
+class RewriteQuantizedSliceOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - return success(); + LogicalResult match(stablehlo::SliceOp op) const override { + return success(IsOpFullyQuantized(op)); } - static LogicalResult MatchOutput(Value output) { - const Type output_element_type = - output.getType().cast().getElementType(); - if (!IsI8F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << output_element_type << ".\n"); - return failure(); + void rewrite(stablehlo::SliceOp op, + PatternRewriter& rewriter) const override { + auto operand_type = op.getOperand().getType().cast(); + Type output_type = op.getResult().getType(); + const int64_t rank = operand_type.getRank(); + + ArrayRef idx_shape(rank); + TensorType idx_type = + operand_type.cloneWith(idx_shape, rewriter.getI32Type()); + + ArrayRef start_idx_i64 = op.getStartIndices(); + ArrayRef limit_idx_i64 = op.getLimitIndices(); + + SmallVector start_idx_i32 = + CastI64ArrayToI32(start_idx_i64).value(); + auto start_idx_attr = DenseIntElementsAttr::get(idx_type, start_idx_i32); + auto start_idx = + rewriter.create(op.getLoc(), start_idx_attr); + + SmallVector slice_size_i32(rank); + for (int i = 0; i < rank; ++i) { + slice_size_i32[i] = + CastI64ToI32(limit_idx_i64[i] - start_idx_i64[i]).value(); + } + auto slice_size_attr = DenseIntElementsAttr::get(idx_type, slice_size_i32); + auto slice_size = + rewriter.create(op.getLoc(), slice_size_attr); + + ArrayRef strides = op.getStrides(); + // If stride of every dimension is 1, create tfl.slice and return early. + // Otherwise, create tfl.strided_slice instead. 
+ if (llvm::all_of(strides, [](int64_t stride) { return stride == 1; })) { + rewriter.replaceOpWithNewOp( + op, output_type, op.getOperand(), start_idx, slice_size); + return; } - return success(); + SmallVector stride_i32 = CastI64ArrayToI32(strides).value(); + auto stride_attr = DenseIntElementsAttr::get(idx_type, stride_i32); + auto stride = rewriter.create(op.getLoc(), stride_attr); + rewriter.replaceOpWithNewOp( + op, output_type, op.getOperand(), start_idx, slice_size, stride, + /*begin_mask=*/0, /*end_mask=*/0, + /*ellipsis_mask=*/0, /*new_axis_mask=*/0, /*shrink_axis_mask=*/0, + /*offset=*/false); } }; -// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or -// `tfl.batch_matmul` when it accepts uniform quantized tensors. -// -// Conditions for `tfl.fully_connected` conversion: -// * Input and output tensors are per-tensor uniform quantized (i8->f32) -// tensors. -// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) -// tensor. The quantization dimension should be 1 (the non-contracting -// dimension). -// * The input tensor's rank is either 2 or 3. The last dimension of the input -// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. -// * The filter tensor's rank is 2. The contracting dimension should be the -// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. -// * Does not consider activation fusion. -// * Does not consider bias add fusion. -// TODO: b/580909703 - Include conversion conditions for `tfl.batch_matmul` op. -// -// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands -// is not specified in the StableHLO dialect. Update the spec to allow this. -class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - +// Rewrites quantized stablehlo.broadcast_in_dim to tfl.broadcast_to. +// tfl.transpose is introduced when broadcast_dimensions is not in ascending +// order. 
Also, tfl.expand_dims is introduced when input rank is smaller than +// output rank. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. +class RewriteQuantizedBroadcastInDimOp + : public OpRewritePattern { public: - LogicalResult match(stablehlo::DotGeneralOp op) const override { - const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = - op.getDotDimensionNumbers(); - if (const int num_rhs_contracting_dims = - dot_dimension_nums.getRhsContractingDimensions().size(); - num_rhs_contracting_dims != 1) { - LLVM_DEBUG(llvm::dbgs() - << "Expected number of contracting dimensions to be 1. Got: " - << num_rhs_contracting_dims << ".\n"); - return failure(); - } + using OpRewritePattern::OpRewritePattern; - if (failed(MatchInput(op.getOperand(0)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match input for quantized dot_general op.\n"); - return failure(); - } + LogicalResult match(stablehlo::BroadcastInDimOp op) const override { + return success(IsOpFullyQuantized(op)); + } - if (failed(MatchFilter(op.getOperand(1)))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match filter for quantized dot_general op.\n"); - return failure(); - } + void rewrite(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto operand_type = op.getOperand().getType().cast(); + auto output_type = op.getResult().getType().cast(); + Value input = op.getOperand(); - if (failed(MatchOutput(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to match output for quantized dot_general op.\n"); - return failure(); + // If broadcast_dimensions is not in ascending order, transpose first. 
+ if (!llvm::is_sorted(op.getBroadcastDimensions())) { + input = InsertTransposeOp(op, rewriter); } - if (failed(MatchUsers(op.getResult()))) { - LLVM_DEBUG(llvm::dbgs() << "Failed to match subsequent requantize for " - "quantized dot_general op.\n"); - return failure(); + // If rank of operand is smaller than that of the output, expand dimensions + // before broadcasting. + if (operand_type.getRank() < output_type.getRank()) { + input = InsertExpandDimsOp(op, rewriter, input, output_type.getRank()); } - return success(); + SmallVector broadcast_shape = + CastI64ArrayToI32(output_type.getShape()).value(); + TensorType broadcast_shape_type = + output_type.cloneWith({output_type.getRank()}, rewriter.getI32Type()); + auto broadcast_shape_attr = + DenseIntElementsAttr::get(broadcast_shape_type, broadcast_shape); + auto shape = + rewriter.create(op.getLoc(), broadcast_shape_attr); + + rewriter.replaceOpWithNewOp(op, output_type, input, + shape); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { - // Create the new filter constant - transpose filter value - // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for - // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas - // `tfl.fully_connected` accepts an OI format. 
- auto filter_constant_op = - cast(op.getOperand(1).getDefiningOp()); + Value InsertTransposeOp(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const { + SmallVector sorted_dims = + llvm::to_vector(op.getBroadcastDimensions()); + llvm::sort(sorted_dims); + auto broadcast_dims = op.getBroadcastDimensions(); + SmallVector permutation( + llvm::map_range(broadcast_dims, [sorted_dims](int64_t dim) { + return static_cast(llvm::find(sorted_dims, dim) - + sorted_dims.begin()); + })); + auto operand_type = op.getOperand().getType().cast(); + TensorType perm_type = operand_type.cloneWith( + {static_cast(permutation.size())}, rewriter.getI32Type()); + auto perm_attr = DenseIntElementsAttr::get(perm_type, permutation); + auto perm = rewriter.create(op.getLoc(), perm_attr); + Value input = op.getOperand(); + + return rewriter.create(op.getLoc(), input, perm); + } - TFL::QConstOp new_filter_constant_op = CreateTflConstOpForFilter( - filter_constant_op, rewriter, /*is_per_axis=*/false); - const Value input_value = op.getOperand(0); - const double input_scale = input_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter, - /*is_per_axis=*/false); + Value InsertExpandDimsOp(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter, Value input, + int64_t output_rank) const { + auto input_type = input.getType().cast(); + SmallVector input_shape(input_type.getShape()); + SmallVector input_dims = + llvm::to_vector(op.getBroadcastDimensions()); + + while (input_dims.size() < output_rank) { + int32_t dim_to_expand = 0; + for (int32_t i = 0; i < output_rank; ++i) { + if (!llvm::is_contained(input_dims, i)) { + dim_to_expand = i; + break; + } + } - auto output_op = op.getResult().getDefiningOp(); - Operation* requantize_op = *output_op->getResult(0).getUsers().begin(); - Operation* dequantize_op = 
*requantize_op->getResult(0).getUsers().begin(); + TensorType dim_type = input_type.cloneWith({static_cast(1)}, + rewriter.getI32Type()); + ArrayRef dims(dim_to_expand); + auto dim_attr = DenseIntElementsAttr::get(dim_type, dims); + auto dim = rewriter.create(op.getLoc(), dim_attr); - // Set to `nullptr` because this attribute only matters when the input is - // dynamic-range quantized. - const BoolAttr asymmetric_quantize_inputs = nullptr; - auto tfl_fully_connected_op = rewriter.create( - op.getLoc(), - /*output=*/ - requantize_op->getResult(0).getType(), // result_value.getType(), - /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), - /*bias=*/bias_constant_op.getResult(), - /*fused_activation_function=*/rewriter.getStringAttr("NONE"), - /*weights_format=*/rewriter.getStringAttr("DEFAULT"), - /*keep_num_dims=*/rewriter.getBoolAttr(false), - asymmetric_quantize_inputs); + input_shape.insert(input_shape.begin() + dim_to_expand, 1); + TensorType expanded_type = input_type.clone(input_shape); + input = rewriter.create(op.getLoc(), expanded_type, + input, dim); + + // Update expanded dimension in the input dimensions for the next + // iteration. + input_dims.push_back(static_cast(dim_to_expand)); + } + return input; + } +}; - auto tfl_dequantize_op = rewriter.create( - op.getLoc(), dequantize_op->getResult(0).getType(), - tfl_fully_connected_op->getResult(0)); +// Rewrites quantized stablehlo.reduce_window with max to tfl.max_pool_2d. +// TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
+class RewriteQuantizedReduceWindowOpWithMax + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; - rewriter.replaceAllUsesWith(dequantize_op->getResult(0), - tfl_dequantize_op->getResult(0)); + LogicalResult MatchBinaryReduceFunction(Region& function) const { + Block& body = function.front(); + if (body.getNumArguments() != 2) return failure(); - rewriter.replaceAllUsesWith(op.getResult(), - tfl_fully_connected_op.getResult(0)); + auto return_op = dyn_cast(body.back()); + if (!return_op) return failure(); + if (return_op.getNumOperands() != 1) return failure(); - rewriter.eraseOp(op); + auto reduce_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!reduce_op) return failure(); + return success(reduce_op.getLhs() == body.getArgument(0) && + reduce_op.getRhs() == body.getArgument(1)); } - private: - static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); - if (!input_type.hasRank() || - !(input_type.getRank() == 2 || input_type.getRank() == 3)) { - LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " - << input_type << ".\n"); + LogicalResult match(stablehlo::ReduceWindowOp op) const override { + // Check that the reduce-window is a max-reduce-window. + if (failed(MatchBinaryReduceFunction(op.getBody()))) { return failure(); } - if (const auto input_element_type = input_type.getElementType(); - !IsI8F32UniformQuantizedType(input_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected an i8->f32 uniform quantized type. Got: " - << input_element_type << ".\n"); + // Only 2d pooling is supported in TFLite. + if (op.getWindowDimensions().size() != 4) { return failure(); } - return success(); - } - - static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); - if (!filter_type.hasRank() || filter_type.getRank() != 2) { - LLVM_DEBUG(llvm::dbgs() - << "Filter tensor expected to have a tensor rank of 2. 
Got: " - << filter_type << ".\n"); + // reduce_window op with dilations or padding will supported later. + // TODO: b/321099943 - Support reduce_window op with dilations and padding. + if (op.getBaseDilations().has_value() || + op.getWindowDilations().has_value() || op.getPadding().has_value()) { return failure(); } - const Type filter_element_type = filter_type.getElementType(); - if (!IsI8F32UniformQuantizedType(filter_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i8->f32) type. Got: " - << filter_element_type << "\n"); + // Window_dimensions and window_strides should have batch and channel + // dimension of 1 as they cannot be specified in tfl.max_pool_2d. + ArrayRef window_dims = op.getWindowDimensions(); + if (window_dims[0] != 1 || window_dims[3] != 1) { return failure(); } - - if (Operation* filter_op = filter.getDefiningOp(); - filter_op == nullptr || !isa(filter_op)) { - LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); - return failure(); + std::optional> window_strides = op.getWindowStrides(); + if (window_strides.has_value()) { + if ((*window_strides)[0] != 1 || (*window_strides)[3] != 1) { + return failure(); + } } - return success(); - } - - static LogicalResult MatchOutput(Value output) { - const Type output_element_type = - output.getType().cast().getElementType(); - if (!IsI32F32UniformQuantizedType(output_element_type)) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized (i32->f32) type. Got: " - << output_element_type << ".\n"); - return failure(); - } - return success(); + return success(IsOpFullyQuantized(op)); } - static LogicalResult MatchUsers(Value output) { - auto output_op = output.getDefiningOp(); - - if (!output_op->hasOneUse()) { - LLVM_DEBUG(llvm::dbgs() << "Expected output to be used only once.\n"); - return failure(); - } - // TODO: b/309896242 - Add support for fused op case. 
- if (Operation* requantize_op = dyn_cast_or_null( - *output_op->getResult(0).getUsers().begin())) { - const Type requantize_element_type = requantize_op->getResult(0) - .getType() - .cast() - .getElementType(); - if (!IsI8F32UniformQuantizedType(requantize_element_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected a quantize (i8->f32) type. Got: " - << requantize_element_type << ".\n"); - return failure(); - } - if (!isa( - *requantize_op->getResult(0).getUsers().begin())) { - LLVM_DEBUG(llvm::dbgs() << "Expected a dequantize type.\n"); - return failure(); - } + void rewrite(stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + Type result_type = op.getResult(0).getType(); + Value input = op.getOperand(0); + // Ops with padding is rejected in matching function, so we can use the + // padding to be 'VALID'. + StringAttr padding = rewriter.getStringAttr("VALID"); + + // Use NHWC format. + int32_t stride_h = 1; + int32_t stride_w = 1; + std::optional> window_strides = op.getWindowStrides(); + if (window_strides.has_value()) { + stride_h = CastI64ToI32((*window_strides)[1]).value(); + stride_w = CastI64ToI32((*window_strides)[2]).value(); } - return success(); + auto stride_h_attr = IntegerAttr::get(rewriter.getI32Type(), stride_h); + auto stride_w_attr = IntegerAttr::get(rewriter.getI32Type(), stride_w); + + ArrayRef window_dims = op.getWindowDimensions(); + auto window_w_attr = IntegerAttr::get(rewriter.getI32Type(), + CastI64ToI32(window_dims[2]).value()); + auto window_h_attr = IntegerAttr::get(rewriter.getI32Type(), + CastI64ToI32(window_dims[1]).value()); + StringAttr activation_function = rewriter.getStringAttr("NONE"); + + rewriter.replaceOpWithNewOp( + op, result_type, input, padding, stride_w_attr, stride_h_attr, + window_w_attr, window_h_attr, activation_function); } }; -void UniformQuantizedStablehloToTflPass::runOnOperation() { +void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); 
MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); patterns.add( - &ctx); + RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp, + RewriteQuantizedConvolutionOp, RewriteQuantizedTransposeOp, + RewriteQuantizedReshapeOp, RewriteQuantizedSelectOp, + RewriteQuantizedConcatenateOp, RewriteQuantizedPadOp, + RewriteQuantizedSliceOp, RewriteQuantizedBroadcastInDimOp, + RewriteQuantizedReduceWindowOpWithMax>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " @@ -1289,11 +1721,11 @@ void UniformQuantizedStablehloToTflPass::runOnOperation() { } // namespace std::unique_ptr> -CreateUniformQuantizedStablehloToTflPass() { - return std::make_unique(); +CreateUniformQuantizedStableHloToTflPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir index cbf7c3dd6cebfe..3c70fd1dfacf42 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/mix_tflite_stablehlo.mlir @@ -9,10 +9,8 @@ func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> { } } -// CHECK:module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { -// CHECK-NEXT: func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "exp"}} { +// CHECK: func.func @main(%arg0: tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "exp"}} { // CHECK-NEXT: %0 = stablehlo.logistic %arg0 : tensor<1x1x1x96xf32> // CHECK-NEXT: %1 = "tfl.exp"(%0) : (tensor<1x1x1x96xf32>) -> tensor<1x1x1x96xf32> // CHECK-NEXT: 
return %1 : tensor<1x1x1x96xf32> -// CHECK-NEXT: } -// CHECK-NEXT:} \ No newline at end of file +// CHECK-NEXT: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir index 75474e7ec8b268..76f778bcebec20 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir @@ -5,7 +5,9 @@ func.func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { ^bb0(%arg0: tensor<3x2xi32>): - // CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} + // CHECK: module attributes + // CHECK-SAME: tfl.description = "MLIR Converted." + // CHECK-SAME: tfl.schema_version = 3 : i32 // CHECK: %{{.*}} = "tfl.pseudo_const"() {value = dense<{{\[\[1, 2\], \[3, 4\], \[5, 6\]\]}}> : tensor<3x2xi32>} // CHECK-NEXT: [[SUB:%.*]] = tfl.sub %{{.*}}, %{{.*}} {fused_activation_function = "RELU6"} : tensor<3x2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir index bf044517c7aa3f..f7fc2f3ff12c01 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo.mlir @@ -351,7 +351,7 @@ func.func @reduce_window(%arg0: tensor<1x160x1xf32>, %arg1: tensor) -> tens ^bb0(%arg23: tensor, %arg24: tensor): %1112 = stablehlo.add %arg23, %arg24 : tensor stablehlo.return %1112 : tensor - }) {base_dilations = dense<1> : tensor<3xi64>, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor) -> tensor<1x160x1xf32> + }) {base_dilations = array, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = array, 
window_dimensions = array, window_strides = array} : (tensor<1x160x1xf32>, tensor) -> tensor<1x160x1xf32> return %0 : tensor<1x160x1xf32> } @@ -360,7 +360,7 @@ func.func @reduce_window(%arg0: tensor<1x160x1xf32>, %arg1: tensor) -> tens //CHECK-NEXT: ^bb0(%arg2: tensor, %arg3: tensor): //CHECK-NEXT: %1 = stablehlo.add %arg2, %arg3 : tensor //CHECK-NEXT: stablehlo.return %1 : tensor -//CHECK-NEXT{LITERAL}: }) {base_dilations = dense<1> : tensor<3xi64>, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 160, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<1x160x1xf32>, tensor) -> tensor<1x160x1xf32> +//CHECK-NEXT{LITERAL}: }) {base_dilations = array, padding = dense<[[0, 0], [159, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<1x160x1xf32>, tensor) -> tensor<1x160x1xf32> //CHECK-NEXT: return %0 : tensor<1x160x1xf32> //CHECK-NEXT:} @@ -507,7 +507,7 @@ func.func @gather(%operand: tensor<3x4x2xi32>, %start_indices: tensor<2x3x2xi64> collapsed_slice_dims = [0], start_index_map = [1, 0], index_vector_dim = 2>, - slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + slice_sizes = array, indices_are_sorted = false } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> return %result : tensor<2x3x2x2xi32> @@ -515,7 +515,7 @@ func.func @gather(%operand: tensor<3x4x2xi32>, %start_indices: tensor<2x3x2xi64> // CHECK: func.func private @gather(%arg0: tensor<3x4x2xi32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> { -// CHECK-NEXT: %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> +// CHECK-NEXT: %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array} : 
(tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> // CHECK-NEXT: return %0 : tensor<2x3x2x2xi32> // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir index f3c64f67fc5f9b..e9708e0f14a877 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/stablehlo_const.mlir @@ -8,9 +8,12 @@ module attributes {tfl.metadata = {"keep_stablehlo_constant" = "true"}} { } } -//CHECK:module attributes {tfl.description = "MLIR Converted.", tfl.metadata = {keep_stablehlo_constant = "true"}, tfl.schema_version = 3 : i32} { -//CHECK-NEXT: func.func @main() -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {outputs = "stablehlo.constant"}} { -//CHECK-NEXT: %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x1x96xf32> -//CHECK-NEXT: return %0 : tensor<1x1x1x96xf32> -//CHECK-NEXT: } -//CHECK-NEXT:} \ No newline at end of file +// CHECK: module attributes { +// CHECK-SAME: tfl.metadata +// CHECK-SAME: keep_stablehlo_constant = "true" + +// CHECK-NEXT: func.func @main() -> tensor<1x1x1x96xf32> attributes {tf.entry_function = {outputs = "stablehlo.constant"}} { +// CHECK-NEXT: %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x1x96xf32> +// CHECK-NEXT: return %0 : tensor<1x1x1x96xf32> +// CHECK-NEXT: } +// CHECK-NEXT:} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir index 5d1566cf121590..07738c1102f767 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/tf_variant_type.mlir @@ -4,6 +4,5 @@ func.func @main(%arg0 : tensor>>, %arg1: tensor>> } -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 
: i32} // CHECK: func.func @main(%[[ARG0:.*]]: tensor>>, %[[ARG1:.*]]: tensor>>) -> tensor>> // CHECK-NEXT: return %[[ARG0]] : tensor>> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 40db57950cc30f..46ff509b7cc46e 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' -tfl-optimize | FileCheck %s func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> { %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index ed5370b7cda3dc..9da0e13c0471ac 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -377,6 +377,28 @@ func.func @DontFuseMulIntoFullyConnectedForLargeFilter(%arg0: tensor<128x256000x } +// CHECK-LABEL: @skipFuseMulIntoFullyConnected +func.func @skipFuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> (tensor<1x8xf32>, tensor<4x2xf32>) { + %cst0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst1 = arith.constant dense<2.0> : tensor<2xf32> + %cst2 = arith.constant dense<[1.0, 2.0]> : tensor<2xf32> + %cst3 = arith.constant dense<[1, 8]> : tensor<2xi32> + + %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + %1 = "tfl.reshape"(%0, %cst3) : (tensor<4x2xf32>, tensor<2xi32>) -> 
tensor<1x8xf32> + %2 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + + func.return %1, %2 : tensor<1x8xf32>, tensor<4x2xf32> + // CHECK: %cst = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32> + // CHECK: %cst_0 = arith.constant dense<2.000000e+00> : tensor<2xf32> + // CHECK: %cst_1 = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + // CHECK: %cst_2 = arith.constant dense<[1, 8]> : tensor<2xi32> + // CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + // CHECK: %1 = "tfl.reshape"(%0, %cst_2) : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<1x8xf32> + // CHECK: %2 = tfl.mul(%0, %cst_1) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + // CHECK: return %1, %2 : tensor<1x8xf32>, tensor<4x2xf32> +} + // CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs func.func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { %cst2 = arith.constant dense<1.5> : tensor @@ -762,26 +784,6 @@ func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x // CHECK: return %0 : tensor<1x3x6x5x8192xf32> } -// CHECK-LABEL: @FuseTransposeIntoBMM_RHS -func.func @FuseTransposeIntoBMM_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> { - %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> - %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x1440x256xf32>, tensor<3xi32>) -> tensor<1x256x1440xf32> - %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<1x256x1440xf32>) -> tensor<1x4x1440x1440xf32> - return %33 : tensor<1x4x1440x1440xf32> - // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x 
= false, adj_y = true} : (tensor<1x4x1440x256xf32>, tensor<1x1440x256xf32>) -> tensor<1x4x1440x1440xf32> - // CHECK: return %0 : tensor<1x4x1440x1440xf32> -} - -// CHECK-LABEL: @FuseTransposeIntoBMM_RHS2 -func.func @FuseTransposeIntoBMM_RHS2(%arg0: tensor, %arg1: tensor) -> tensor { - %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> - %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor, tensor<3xi32>) -> tensor - %33 = "tfl.batch_matmul"(%arg0, %32) {adj_x = false, adj_y = false} : (tensor, tensor) -> tensor - return %33 : tensor - // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor, tensor) -> tensor - // CHECK: return %0 : tensor -} - // CHECK-LABEL: @FuseTransposeIntoBMM_LHS func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> @@ -3897,3 +3899,163 @@ func.func @NoReorderNCHWTransposeAddNotBias(%arg0: tensor<1x40x40x1xf32>, %filte // CHECK: %[[add:.*]] = tfl.add %[[transpose]], // CHECK: return %[[add]] } + +// CHECK-LABEL: @ConvertStridedSliceToSlice +func.func @ConvertStridedSliceToSlice(%arg0: tensor<2x3872x1x128xf32>) -> tensor<1x3872x1x128xf32> { + %44 = arith.constant dense<0> : tensor<4xi32> + %45 = arith.constant dense<[1, 3872, 1, 128]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<2x3872x1x128xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3872x1x128xf32> + func.return %47 : tensor<1x3872x1x128xf32> + + // CHECK: %[[slice:.*]] = "tfl.slice" + // CHECK: return %[[slice]] +} + +// CHECK-LABEL: @FuseExcessBroadcastingOnReshapes +func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x1x1x128xf32> { + %cst = arith.constant dense<[1, 1, 1, 
8, 1, 1]> : tensor<6xi32> + %cst_0 = arith.constant dense<[1, 1, 1, 8, 16, 1]> : tensor<6xi32> + %cst_1 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x8xf32>, tensor<6xi32>) -> tensor<1x1x1x8x1x1xf32> + %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> + %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + return %2 : tensor<1x1x1x128xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> + // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> + // CHECK: %1 = tfl.mul(%0, %cst) {fused_activation_function = "NONE"} : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + // CHECK: return %2 : tensor<1x1x1x128xf32> +} + +// CHECK-LABEL: @FuseExcessBroadcastingOnReshapesDynamicShapes +func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<6xi32>, %arg2: tensor<6xi32>, %arg3: tensor<2xi32>) -> tensor { + %1196 = "tfl.reshape"(%arg0, %arg1) : (tensor, tensor<6xi32>) -> tensor<1x?x1x10x1x1xf32> + %1197 = "tfl.broadcast_to"(%1196, %arg2) : (tensor<1x?x1x10x1x1xf32>, tensor<6xi32>) -> tensor<1x?x1x10x5x1xf32> + %1198 = "tfl.reshape"(%1197, %arg3) : (tensor<1x?x1x10x5x1xf32>, tensor<2xi32>) -> tensor + return %1198 : tensor + + // CHECK: %0 = "tfl.reshape"(%arg0, %arg1) : (tensor, tensor<6xi32>) -> tensor<1x?x1x10x1x1xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %arg2) : (tensor<1x?x1x10x1x1xf32>, tensor<6xi32>) -> tensor<1x?x1x10x5x1xf32> + // CHECK: %2 = "tfl.reshape"(%1, %arg3) : (tensor<1x?x1x10x5x1xf32>, tensor<2xi32>) -> tensor + // CHECK: return %2 : tensor +} + +// 
CHECK-LABEL: @broadcast_to_f32_low_dim +func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i32_low_dim +func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> +} + +// CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape +func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i16_low_dim +func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> + return %0 : tensor<3x3xi16> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xi16>, tensor<3x3xi16>) -> 
tensor<3x3xi16> + // CHECK: return %0 : tensor<3x3xi16> +} + +// CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output +func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> + // CHECK: return %1 : tensor<*xi32> +} + +// CHECK-LABEL: @broadcast_to_ui32 +func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> + return %0 : tensor<10xui32> + // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor<10xui32>) -> tensor<10xui32> + // CHECK: return %0 : tensor<10xui32> +} + +// CHECK-LABEL: @broadcast_to_f32 +func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> + // CHECK: %0 = tfl.mul(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: return %0 : tensor<3x3xf32> +} + +// CHECK-LABEL: @broadcast_to_i32 +func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> + // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> + // CHECK: %0 = tfl.mul(%arg0, %cst) 
{fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> + // CHECK: return %0 : tensor<3x3xi32> +} + +// CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output +func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { + %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> + return %0 : tensor<3x?xi32> + // CHECK: %cst = arith.constant dense<1> : tensor + // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> + // CHECK: %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> + // CHECK: return %1 : tensor<3x?xi32> +} + +// CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output +func.func @broadcast_to_ui32_with_dynamic_output(%arg0: tensor<1xi32>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xui32> + %0 = "tfl.broadcast_to"(%cst, %arg0) : (tensor<1xui32>, tensor<1xi32>) -> tensor + return %0 : tensor + + // CHECK: %cst = arith.constant dense<0> : tensor<1xui32> + // CHECK: %0 = "tfl.broadcast_to"(%cst, %arg0) : (tensor<1xui32>, tensor<1xi32>) -> tensor + // CHECK: return %0 : tensor +} + + +// CHECK-LABEL: @ConvertStridedSliceToSliceNeg +func.func @ConvertStridedSliceToSliceNeg(%arg0: tensor<5x5x5x5xf32>) -> tensor<*xf32> { + %44 = arith.constant dense<[5, 5, 5, 5]> : tensor<4xi32> + %45 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<5x5x5x5xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32> + func.return %47 : tensor<*xf32> + + // CHECK-NOT: %[[slice:.*]] = "tfl.slice" +} + +// CHECK-LABEL: @StridedSliceToSliceBeginNeg +func.func @StridedSliceToSliceBeginNeg(%arg0: 
tensor<5x5x5x5xf32>) -> tensor<*xf32> { + %44 = arith.constant dense<[-5, 0, 0, 0]> : tensor<4xi32> + %45 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi32> + %46 = arith.constant dense<1> : tensor<4xi32> + %47 = "tfl.strided_slice"(%arg0, %44, %45, %46) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<5x5x5x5xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<*xf32> + func.return %47 : tensor<*xf32> + + // CHECK-NOT: %[[slice:.*]] = "tfl.slice" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index 0c9f058c1912c9..b35355524127dc 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -103,8 +103,8 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK-DAG: %[[w:.*]] = arith.constant dense<1.270000e+02> : tensor<512x12xf32> -// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>>} -// CHECK-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<512x12xf32> +// CHECK-DAG: %[[q_w:.*]] = "tfl.quantize"(%[[w]]) {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, +// CHECK-DAG: %[[dq_w:.*]] = "tfl.dequantize"(%[[q_w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00, // CHECK-DAG: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { // CHECK-NOT: fused_activation_function = "NONE" diff --git 
a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir index d2e04734e0e2e6..15ede0019e12d6 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir @@ -298,8 +298,8 @@ func.func @QuantizeFullyConnectedOp(%arg0: tensor<1x3xf32>) -> (tensor<1x1xf32>) // CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32> // CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32, {{.*}}>> -// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32> // CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x3x!quant.uniform>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%[[dq3]], %[[dq2]], %[[dq1]]) {{{.*}}} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1xf32>) -> tensor<1x1xf32> @@ -324,8 +324,8 @@ func.func @QuantizeReshapeAndFullyConnectedOp(%arg0: tensor<1x1x3xf32>) -> (tens // CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[cst]]) {qtype = 
tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x!quant.uniform>) -> tensor<1xf32> // CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<1x3xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32, {{.*}}>> -// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<1x3x!quant.uniform:f32:0, {{.*}}>>, volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform:f32:0, {{.*}}>> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x3x!quant.uniform:f32:0, {{.*}}>>) -> tensor<1x3xf32> // CHECK-NEXT: %[[cst_1:.*]] = arith.constant dense<[-1, 3]> : tensor<2xi32> // CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x3x!quant.uniform>, volatile} : (tensor<1x1x3xf32>) -> tensor<1x1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q3]]) : (tensor<1x1x3x!quant.uniform>) -> tensor<1x1x3xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index 6e9ca99e11f492..882b335135cf74 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -166,20 +166,20 @@ func.func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 // CHECK-LABEL: QuantizeFullyConnected // PerTensor-LABEL: QuantizeFullyConnected -func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { - %w = arith.constant dense<127.0> : tensor<32x12xf32> - %b = arith.constant dense<0.0> : tensor<32xf32> - %fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = 
"NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> - func.return %fc : tensor<1x112x112x32xf32> - -// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x12xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} -// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x4xf32> { + %w = arith.constant dense<127.0> : tensor<4x12xf32> + %b = arith.constant dense<0.0> : tensor<4xf32> + %fc = "tfl.fully_connected"(%arg0, %w, %b) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<4x12xf32>, tensor<4xf32>) -> tensor<1x112x112x4xf32> + func.return %fc : tensor<1x112x112x4xf32> + +// CHECK: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32> +// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>, volatile} +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<4x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>) -> tensor<4x12xf32> // CHECK: "tfl.fully_connected"(%arg0, %[[dq]] -// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<32x12xf32> -// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} -// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> +// PerTensor: %[[cst:.*]] = arith.constant dense<1.270000e+02> : tensor<4x12xf32> +// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<4x12x!quant.uniform:f32, 1.000000e+00>>, volatile} +// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : 
(tensor<4x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<4x12xf32> // PerTensor: "tfl.fully_connected"(%arg0, %[[dq]] } @@ -215,8 +215,8 @@ func.func @bias_adjust_pertensor(%arg0: tensor<1x2xf32>) -> (tensor<1x2xf32>) { func.return %fc : tensor<1x2xf32> // CHECK-DAG: %[[weight:.*]] = arith.constant dense<{{\[\[}}0.000000e+00, 1.000000e+00] // CHECK-DAG: %[[bias:.*]] = arith.constant dense<[0.000000e+00, 2147364.75]> -// CHECK-DAG: %[[b_q:.*]] = "tfl.quantize"(%[[bias]]){{.*}}quant.uniform> -// CHECK-DAG: %[[w_q:.*]] = "tfl.quantize"(%[[weight]]){{.*}}quant.uniform:f32, 19998.892343977564>> +// CHECK-DAG: %[[b_q:.*]] = "tfl.quantize"(%[[bias]]){{.*}}quant.uniform> +// CHECK-DAG: %[[w_q:.*]] = "tfl.quantize"(%[[weight]]){{.*}}quant.uniform:f32:0, {0.0078740157480314959,19998.892343977564}>> // CHECK-DAG: %[[b_dq:.*]] = "tfl.dequantize"(%[[b_q]]) // CHECK-DAG: %[[w_dq:.*]] = "tfl.dequantize"(%[[w_q]]) // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%[[input:.*]], %[[w_dq]], %[[b_dq]]) diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 2a4b2af88f5319..cce986eb8f1a8e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -1,5 +1,6 @@ // RUN: tf-opt %s -tfl-prepare-quantize="quantize-allowlist=quantize_float_placeholder_only,not_reset_input" | FileCheck %s // RUN: tf-opt %s -tfl-prepare-quantize="disable-set-input-nodes-quantization-params=true" | FileCheck --check-prefix=MixedPrecision %s +// RUN: tf-opt %s -tfl-prepare-quantize="is-qdq-conversion=true" | FileCheck --check-prefix=QDQ %s // CHECK-LABEL: main // Uses `main` function to match the default target function of QuantSpecs and @@ -394,6 +395,32 @@ func.func @NotRescaleLogistic(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : 
(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + func.return %1 : tensor<1x6x6x16xf32> + +// QDQ: %0 = "tfl.dequantize"(%arg0) +// QDQ: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> +// QDQ-NOT:"tfl.quantize" +// QDQ: return %1 : tensor<1x6x6x16xf32> +} + +// QDQ-LABEL: QDQNoQuantizeSoftmax +func.func @QDQNoQuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + func.return %1 : tensor<1x6x6x16xf32> + +// QDQ: %0 = "tfl.dequantize"(%arg0) +// QDQ: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> +// QDQ-NOT: "tfl.quantize" +// QDQ: return %1 : tensor<1x6x6x16xf32> +} + // CHECK-LABEL: QuantizeL2Norm func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 0769e768507ee7..bfbcbd573cb0e0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -440,34 +440,13 @@ func.func @StridedSliceShrinkAxisAndNewAxisMaskBothSet(%arg0: tensor<6x7x8xf32>) // CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%[[RESHAPE]], %[[BEGIN]], %[[END]], %[[STEP]]) <{begin_mask = 26 : i64, ellipsis_mask = 0 : i64, end_mask = 26 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<6x1x7x1x8xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<1x4x1x8xf32> } -func.func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> 
tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - func.return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func.func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - func.return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - func.func @broadcast_to_i16_low_dim(%input: tensor<3xi16>, %shape: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> func.return %0: tensor<3x3xi16> // CHECK-LABEL: broadcast_to_i16_low_dim -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi16> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> -// CHECK: return [[MUL]] : tensor<3x3xi16> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> +// CHECK: return %0 : tensor<3x3xi16> } func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { @@ -475,9 +454,8 @@ func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: func.return %0: tensor<3x3xf32> // CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) 
-> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> +// CHECK: return %0 : tensor<3x3xf32> } func.func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> { @@ -485,10 +463,8 @@ func.func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, % func.return %0: tensor<*xi32> // CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor -// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor) -> tensor<*xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> -// CHECK: return [[MUL]] : tensor<*xi32> +// CHECK: %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> +// CHECK: return %0 : tensor<*xi32> } func.func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> { @@ -517,16 +493,6 @@ func.func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6x // CHECK: "tf.BroadcastTo"(%arg0, %arg1) } -func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> - func.return %0: tensor<10xui32> - -// CHECK-LABEL: broadcast_to_ui32 -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<10xui32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor, tensor<10xui32>) -> tensor<10xui32> -// CHECK: return [[MUL]] : tensor<10xui32> -} - // CHECK-LABEL: xla_conv_v2 func.func @xla_conv_v2(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1") @@ -541,26 +507,6 @@ func.func @xla_conv_v2(%arg0: tensor<4x8x8x16xf32>) -> 
tensor<4x8x8x16xf32> { // CHECK: return %[[RES]] } -func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { - %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - func.return %0: tensor<3x3xf32> - -// CHECK-LABEL: broadcast_to_f32 -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xf32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> -// CHECK: return [[MUL]] : tensor<3x3xf32> -} - -func.func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> { - %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> - func.return %0: tensor<3x3xi32> - -// CHECK-LABEL: broadcast_to_i32 -// CHECK: [[CST:%.*]] = arith.constant dense<1> : tensor<3x3xi32> -// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> -// CHECK: return [[MUL]] : tensor<3x3xi32> -} - // CHECK-LABEL: lower_rfft_to_rfft2d func.func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex> { %0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir index 15684bc4bd2204..ad4ff5a129f4a2 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir @@ -88,7 +88,7 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.return %fc : tensor<1x112x112x512xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = 
tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[w]], %[[b]]) { // CHECK-NOT: fused_activation_function = "NONE", // CHECK-SAME: asymmetric_quantize_inputs = true, @@ -102,8 +102,8 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 // PerTensor: return %[[fc:.*]] // PerChannelWeightOnly: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<512xf32> -// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32, 1.000000e+00>> -// PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32, 1.000000e+00>> +// PerChannelWeightOnly: %[[w:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 +// PerChannelWeightOnly: %[[dq_w:.*]] = "tfl.dequantize"(%[[w]]) : (tensor<512x12x!quant.uniform:f32:0, {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00 // PerChannelWeightOnly: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[dq_w]], %[[b]]) { // PerChannelWeightOnly-NOT: fused_activation_function = "NONE", // PerChannelWeightOnly-SAME: asymmetric_quantize_inputs = true, diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 6da1d1e546ea32..749e5f383b289b 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -19,21 +19,22 @@ limitations under the License. 
#include #include +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/experimental/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -47,7 +48,7 @@ CreateTFExecutorToControlDialectConversion(); namespace tensorflow { namespace { // Data layout supported by TFLite. 
-const char kTFLiteDataLayout[] = "NHWC"; +constexpr mlir::StringRef kTFLiteDataLayout = "NHWC"; } // namespace void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, @@ -136,12 +137,26 @@ void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); } -void AddConvertHloToTfPass(std::string entry_function_name, - const mlir::TFL::PassConfig& pass_config, - mlir::OpPassManager* pass_manager) { - pass_manager->addPass( +void AddPreQuantizationStableHloToTfPasses( + const mlir::StringRef entry_function_name, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager) { + pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + // Add CHLO to StableHLO Decompositions: + // This is needed since we are relying on XlaCallModule uses MHLO + // specific features like mhlo::ErfOp which aren't supported + // in StableHLO, but we have CHLO->StableHLO decompositions to legalize. + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pass_manager.addPass( + mlir::stablehlo::experimental::createChloRecomposeOpsPass()); + pass_manager.addNestedPass( + mlir::mhlo::createChloLegalizeToHloBasisOpsPass()); + pass_manager.addNestedPass( + mlir::mhlo::createChloLegalizeToHloPass()); + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + // The following two passes find specific uniform quantization patterns in // StableHLO and converts them to TFLite ops that accept or produce uniform // quantized types. They only target a specific set of models that contain @@ -153,65 +168,93 @@ void AddConvertHloToTfPass(std::string entry_function_name, // There are future plans to make the framework to directly produce StableHLO // uniform quantized ops and deprecate `ComposeUniformQuantizedTypePass`. If // no quantization patterns are found, it is a no-op. 
- pass_manager->addPass(mlir::odml::CreateComposeUniformQuantizedTypePass()); - pass_manager->addNestedPass( - mlir::odml::CreateUniformQuantizedStablehloToTflPass()); + pass_manager.addPass(mlir::odml::CreateComposeUniformQuantizedTypePass()); + pass_manager.addNestedPass( + mlir::odml::CreateUniformQuantizedStableHloToTflPass()); - pass_manager->addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Legalize jax random to tflite custom op. // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace // the random function body before being inlined. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::TFL::CreateLegalizeJaxRandomPass()); // Canonicalize, CSE etc. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::createCanonicalizerPass()); - pass_manager->addNestedPass(mlir::createCSEPass()); + pass_manager.addNestedPass(mlir::createCSEPass()); // DCE for private symbols. - pass_manager->addPass(mlir::createSymbolDCEPass()); + pass_manager.addPass(mlir::createSymbolDCEPass()); - pass_manager->addPass(mlir::TF::CreateStripNoinlineAttributePass()); + pass_manager.addPass(mlir::TF::CreateStripNoinlineAttributePass()); // Add inline pass. - pass_manager->addPass(mlir::createInlinerPass()); + pass_manager.addPass(mlir::createInlinerPass()); // Expands mhlo.tuple ops. - pass_manager->addPass( - mlir::mhlo::createExpandHloTuplesPass(entry_function_name)); + pass_manager.addPass( + mlir::mhlo::createExpandHloTuplesPass(entry_function_name.str())); // Flatten tuples for control flows. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::mhlo::createFlattenTuplePass()); - mlir::odml::AddMhloOptimizationPasses(*pass_manager); + mlir::odml::AddMhloOptimizationPasses(pass_manager, + pass_config.enable_stablehlo_quantizer); // Undo the MHLO::BroadcastInDimOp folding pattern on splat constants. 
This // pass must be added right before the legalization because pattern rewriter // driver applies folding by default. - // TODO(b/295966255): Remove this pass after moving MHLO folders to a separate - // pass. - pass_manager->addPass(mlir::odml::CreateUnfoldSplatConstantPass()); + // TODO: b/295966255 - Remove this pass after moving MHLO folders to a + // separate pass. + pass_manager.addPass(mlir::odml::CreateUnfoldSplatConstantPass()); + + if (pass_config.enable_stablehlo_quantizer) { + // When using StableHLO Quantizer, MHLO ops should be transformed back into + // StableHLO because the quantizer takes StableHLO dialect as its input. + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + } +} + +void AddPostQuantizationStableHloToTfPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager) { + if (pass_config.enable_stablehlo_quantizer) { + // StableHLO Quantizer emits quantized StableHLO module serialized within a + // XlaCallModule op. Add this pass to extract StableHLO module from the + // XlaCallModuleOp. + pass_manager.addPass( + mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + + // Convert StableHLO -> TFLite for fused quantization patterns early so that + // quantized types do not go through the TF dialect which doesn't support + // quantized types. + pass_manager.addNestedPass( + mlir::odml::CreateUniformQuantizedStableHloToTflPass()); + + // StableHLO -> MHLO + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + } // TFLite dialect passes. if (!pass_config.disable_hlo_to_tfl_conversion) { - pass_manager->addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); + pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfLitePass()); } // TF dialect passes - pass_manager->addPass(mlir::odml::CreateLegalizeHloToTfPass()); + pass_manager.addPass(mlir::odml::CreateLegalizeHloToTfPass()); // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. 
This needs to be run immediately after HLO->TF // legalization; otherwise other passes like `ConvertTFBroadcastTo` will // constant fold the newly generated TF broadcast ops and materialize the // weights. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::TF::CreateBroadcastFoldPass()); // Canonicalization after TF legalization. - pass_manager->addNestedPass( + pass_manager.addNestedPass( mlir::createCanonicalizerPass()); // Legalize all remaining mhlo ops to stableHLO - pass_manager->addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); } // This is the early part of the conversion in isolation. This enables a caller @@ -220,11 +263,6 @@ void AddConvertHloToTfPass(std::string entry_function_name, void AddPreVariableFreezingTFToTFLConversionPasses( const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager) { - if (pass_config.enable_hlo_to_tf_conversion) { - // TODO(b/194747383): We need to valid that indeed the "main" func is - // presented. - AddConvertHloToTfPass("main", pass_config, pass_manager); - } // This pass wraps all the tf.FakeQuant ops in a custom op so they are not // folded before being converted to tfl.quantize and tfl.dequantize ops. auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps(); @@ -266,7 +304,7 @@ void AddPreVariableFreezingTFToTFLConversionPasses( // This decomposes resource ops like ResourceGather into read-variable op // followed by gather. This is used when the saved model import path is used - // during which resources dont get frozen in the python layer. + // during which resources don't get frozen in the python layer. 
pass_manager->addNestedPass( mlir::TFDevice::CreateDecomposeResourceOpsPass()); @@ -375,7 +413,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( // Force layout supported by TFLite, this will transpose the data // to match 'kTFLiteDataLayout' mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; - layout_optimization_options.force_data_format = kTFLiteDataLayout; + layout_optimization_options.force_data_format = kTFLiteDataLayout.str(); layout_optimization_options.skip_fold_transpose_in_ops = true; mlir::TF::CreateLayoutOptimizationPipeline( pass_manager->nest(), layout_optimization_options); @@ -436,7 +474,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( // completed. Add either full integer quantization or dynamic range // quantization passes based on quant_specs. if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses() || - pass_config.qdq_conversion_mode != + pass_config.quant_specs.qdq_conversion_mode != mlir::quant::QDQConversionMode::kQDQNone) { AddQuantizationPasses(pass_config, *pass_manager); // Remove unnecessary QDQs while handling QAT models. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 50bf75023a808a..8de2142f0ebd83 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -32,6 +33,25 @@ void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir, const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager); +// Adds the first portion of StableHLO->TF passes happening before quantization. 
+// The `pass_manager` that runs on a `mlir::ModuleOp` expects a graph containing +// a `mlir::TF::XlaCallModuleOp` with serialized StableHLO module. The resulting +// `mlir::ModuleOp` after running these passes will be an MHLO module, or a +// StableHLO module if `pass_config.enable_stablehlo_quantizer` is `true`. This +// is because StableHLO Quantizer accepts StableHLO modules. +void AddPreQuantizationStableHloToTfPasses( + mlir::StringRef entry_function_name, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + +// Adds the second portion of StableHlo->TF passes happening after quantization. +// The input module is expected to be an MHLO module, or a quantized StableHLO +// graph (expressed as `mlir::TF::XlaCallModuleOp`s) if +// `pass_config.enable_stablehlo_quantizer` is `true`. +void AddPostQuantizationStableHloToTfPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + // This is the early part of the conversion in isolation. This enables a caller // to inject more information in the middle of the conversion before resuming it // (like freezing variables for example). diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index dc0ae41ba49a2b..892eed27385035 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -17,56 +17,55 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" -#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include 
"tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" #include "xla/translate/hlo_to_mhlo/translate.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tsl/platform/statusor.h" using mlir::MLIRContext; using mlir::ModuleOp; using mlir::func::FuncOp; -using tsl::StatusOr; // Debugging flag to print function mapping in the flatbuffer. // NOLINTNEXTLINE @@ -170,7 +169,7 @@ int main(int argc, char **argv) { context.appendDialectRegistry(registry); } - StatusOr> module; + absl::StatusOr> module; std::unordered_set tags; tensorflow::GraphImportConfig specs; @@ -321,7 +320,7 @@ int main(int argc, char **argv) { if (bundle) session = bundle->GetSession(); auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.value().get(), output_mlir, toco_flags, pass_config, tags, - /*saved_model_dir=*/"", session, &result, serialize_stablehlo_ops); + /*saved_model_dir=*/"", bundle.get(), &result, serialize_stablehlo_ops); if (!status.ok()) { llvm::errs() << status.message() << '\n'; return kTrFailure; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1b3c5d21ba3dfb..b6a08b35d69445 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include #include #include #include @@ -23,60 +24,75 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include 
"tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/experimental/remat/metadata_util.h" #include 
"tensorflow/lite/python/metrics/converter_error_data.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep namespace tensorflow { namespace { + using mlir::MLIRContext; using mlir::ModuleOp; using mlir::Operation; using mlir::OwningOpRef; -using tsl::StatusOr; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::quantization::PyFunctionLibrary; bool IsControlFlowV1Op(Operation* op) { return mlir::isa extra_tf_opdefs) { +absl::Status RegisterExtraTfOpDefs( + absl::Span extra_tf_opdefs) { for (const auto& tf_opdefs_string : extra_tf_opdefs) { - tensorflow::OpDef opdef; - if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, - &opdef)) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; - return errors::InvalidArgument("fail to parse extra OpDef"); + return absl::InvalidArgumentError("fail to parse extra OpDef"); } // Register extra opdefs. - // TODO(b/133770952): Support shape functions. - tensorflow::OpRegistry::Global()->Register( - [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { - *op_reg_data = tensorflow::OpRegistrationData(opdef); - return OkStatus(); + // TODO: b/133770952 - Support shape functions. + OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); }); } - return OkStatus(); + return absl::OkStatus(); +} + +// The hlo->tf conversion is done in three steps; pre-quantization, +// quantization, and post-quantization. 
Quantization is optional, enabled only +// when `pass_config.enable_stablehlo_quantizer` is `true`. If quantization is +// not run, it only performs the hlo->tf conversion. +// +// All parameters except for `pass_config`, `pass_manager`, `status_handler`, +// and `module` are only required for quantization. See the comments of +// `RunQuantization` for details. If quantization is not performed, they will be +// ignored. +// +// Returns a failure status when any of the three steps fail. `pass_manager` +// will be cleared before returning. +mlir::LogicalResult RunHloToTfConversion( + const mlir::TFL::PassConfig& pass_config, + const absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const QuantizationConfig& quantization_config, + const PyFunctionLibrary* quantization_py_function_lib, + const SavedModelBundle* saved_model_bundle, mlir::PassManager& pass_manager, + mlir::StatusScopedDiagnosticHandler& status_handler, ModuleOp& module) { + // TODO: b/194747383 - We need to valid that indeed the "main" func is + // presented. 
+ AddPreQuantizationStableHloToTfPasses(/*entry_function_name=*/"main", + pass_config, pass_manager); + if (failed(pass_manager.run(module))) { + return mlir::failure(); + } + pass_manager.clear(); + + if (pass_config.enable_stablehlo_quantizer) { + const absl::StatusOr quantized_module_op = RunQuantization( + saved_model_bundle, saved_model_dir, saved_model_tags, + quantization_config, quantization_py_function_lib, module); + if (!quantized_module_op.ok()) { + LOG(ERROR) << "Failed to run quantization: " + << quantized_module_op.status(); + return mlir::failure(); + } + module = *quantized_module_op; + } + + AddPostQuantizationStableHloToTfPasses(pass_config, pass_manager); + if (failed(pass_manager.run(module))) { + return mlir::failure(); + } + pass_manager.clear(); + + return mlir::success(); } + } // namespace -StatusOr> LoadFromGraphdefOrMlirSource( +absl::StatusOr> LoadFromGraphdefOrMlirSource( const std::string& input_filename, bool input_mlir, bool use_splatted_constant, const std::vector& extra_tf_opdefs, const GraphImportConfig& specs, absl::string_view debug_info_file, @@ -156,8 +224,8 @@ StatusOr> LoadFromGraphdefOrMlirSource( std::string error_message; auto file = mlir::openInputFile(input_filename, &error_message); if (!file) { - llvm::errs() << error_message << "\n"; - return errors::InvalidArgument("fail to open input file"); + return absl::InvalidArgumentError( + absl::StrCat("Failed to open input file: ", error_message)); } if (input_mlir) { @@ -170,7 +238,7 @@ StatusOr> LoadFromGraphdefOrMlirSource( auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); if (!extra_opdefs_status.ok()) return extra_opdefs_status; - ::tensorflow::GraphdefToMlirOptions graphdef_conversion_options{ + GraphdefToMlirOptions graphdef_conversion_options{ std::string(debug_info_file), /*xla_compile_device_type=*/"", /*prune_unused_nodes=*/specs.prune_unused_nodes, @@ -182,21 +250,21 @@ StatusOr> LoadFromGraphdefOrMlirSource( /*enable_soft_placement=*/false}; 
if (use_splatted_constant) { - return tensorflow::GraphdefToSplattedMlirTranslateFunction( + return GraphdefToSplattedMlirTranslateFunction( file->getBuffer(), input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, graphdef_conversion_options, context); } - return tensorflow::GraphdefToMlirTranslateFunction( - file->getBuffer(), input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, graphdef_conversion_options, - context); + return GraphdefToMlirTranslateFunction(file->getBuffer(), input_arrays, + input_dtypes, input_shapes, + output_arrays, control_output_arrays, + graphdef_conversion_options, context); } // Applying post-training dynamic range quantization from the old TOCO quantizer // on the translated_result using quant_specs and saving the final output in // result. -Status ApplyDynamicRangeQuantizationFromOldQuantizer( +absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( const mlir::quant::QuantizationSpecs& quant_specs, std::string translated_result, std::string* result) { flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); @@ -206,14 +274,14 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer( ::tflite::optimize::BufferType quantized_type; switch (quant_specs.inference_type) { - case tensorflow::DT_QINT8: + case DT_QINT8: quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; break; - case tensorflow::DT_HALF: + case DT_HALF: quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; break; default: - return errors::InvalidArgument("Quantized type not supported"); + return absl::InvalidArgumentError("Quantized type not supported"); break; } @@ -221,59 +289,59 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer( if (::tflite::optimize::QuantizeWeights( &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) { - return errors::InvalidArgument("Quantize weights transformation 
failed."); + return absl::InvalidArgumentError( + "Quantize weights transformation failed."); } const uint8_t* q_buffer = q_builder.GetBufferPointer(); *result = - string(reinterpret_cast(q_buffer), q_builder.GetSize()); + std::string(reinterpret_cast(q_buffer), q_builder.GetSize()); - return OkStatus(); + return absl::OkStatus(); } -Status ConvertTFExecutorToStablehloFlatbuffer( +absl::Status ConvertTFExecutorToStablehloFlatbuffer( mlir::PassManager& pass_manager, mlir::ModuleOp module, bool export_to_mlir, - mlir::StatusScopedDiagnosticHandler& statusHandler, + mlir::StatusScopedDiagnosticHandler& status_handler, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, - std::optional session, std::string* result, + std::optional session, std::string* result, const std::unordered_set& saved_model_tags) { // Currently, TF quantization only support dynamic range quant, as such // when toco flag post training quantization is specified with converting to // stablehlo, we automatically enable dynamic range quantization if (toco_flags.post_training_quantize()) { - const auto status = tensorflow::quantization::PreprocessAndFreezeGraph( + const auto status = quantization::PreprocessAndFreezeGraph( module, module.getContext(), session); if (!status.ok()) { - return errors::Aborted("Failed to preprocess & freeze TF graph"); + return status_handler.Combine( + absl::InternalError("Failed to preprocess & freeze TF graph.")); } - // TODO(b/264218457): Refactor the component below once StableHLO Quantizer + // TODO: b/264218457 - Refactor the component below once StableHLO Quantizer // can run DRQ. Temporarily using TF Quantization for StableHLO DRQ. if (!toco_flags.has_quantization_options()) { // The default minimum number of elements a weights array must have to be // quantized by this transformation. 
const int kWeightsMinNumElementsDefault = 1024; - tensorflow::quantization::QuantizationOptions quantization_options; + quantization::QuantizationOptions quantization_options; quantization_options.mutable_quantization_method()->set_preset_method( - tensorflow::quantization::QuantizationMethod:: - METHOD_DYNAMIC_RANGE_INT8); - quantization_options.set_op_set( - tensorflow::quantization::UNIFORM_QUANTIZED); + quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); + quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); quantization_options.set_min_num_elements_for_weights( kWeightsMinNumElementsDefault); - tensorflow::quantization::AddQuantizePtqDynamicRangePasses( - pass_manager, quantization_options); + quantization::AddQuantizePtqDynamicRangePasses(pass_manager, + quantization_options); } if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } } pass_manager.clear(); - mlir::odml::AddTFToStablehloPasses(pass_manager, /*skip_resize*/ true, - /*smuggle_disallowed_ops*/ true); + mlir::odml::AddTFToStablehloPasses(pass_manager, /*skip_resize=*/true, + /*smuggle_disallowed_ops=*/true); // Print out a detailed report of non-converted stats. 
pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedStableHLODialects())); @@ -283,13 +351,13 @@ Status ConvertTFExecutorToStablehloFlatbuffer( pass_manager, toco_flags.quantization_options()); } if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } if (export_to_mlir) { llvm::raw_string_ostream os(*result); module.print(os); - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } // Write MLIR Stablehlo dialect into FlatBuffer @@ -301,24 +369,20 @@ Status ConvertTFExecutorToStablehloFlatbuffer( options.metadata[tflite::kModelUseStablehloTensorKey] = "true"; if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result, true)) { - auto s = statusHandler.ConsumeStatus(); - std::string message = "Could not translate MLIR to FlatBuffer."; - if (!s.ok()) { - absl::StrAppend(&message, " ", s.ToString()); - } - return absl::UnknownError(message); + return status_handler.Combine( + absl::InternalError("Could not translate MLIR to FlatBuffer.")); } - return OkStatus(); + return absl::OkStatus(); } -Status ConvertTFExecutorToTFLOrFlatbuffer( +absl::Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, - std::optional session, std::string* result, - bool serialize_stablehlo_ops) { + llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, + std::string* result, bool serialize_stablehlo_ops, + const PyFunctionLibrary* quantization_py_function_lib) { // Explicitly disable dumping Op details on failures. 
module.getContext()->printOpOnDiagnostic(false); @@ -326,92 +390,77 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::func::registerAllExtensions(registry); module.getContext()->appendDialectRegistry(registry); - // Register a warning handler only log to std out. - mlir::ScopedDiagnosticHandler s( - module.getContext(), [](mlir::Diagnostic& diag) { - if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) { - for (auto& note : diag.getNotes()) { - std::cout << note.str() << "\n"; - LOG(WARNING) << note.str() << "\n"; - } - } - return mlir::failure(); - }); - - mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), - /*propagate=*/true); - - if (failed(IsValidGraph(module))) { - return statusHandler.ConsumeStatus(); - } + mlir::StatusScopedDiagnosticHandler status_handler(module.getContext(), + /*propagate=*/true); mlir::PassManager pass_manager(module.getContext()); mlir::registerPassManagerCLOptions(); if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) { - return absl::UnknownError("failed to apply MLIR pass manager CL options"); + return absl::InternalError("Failed to apply MLIR pass manager CL options."); } + InitPassManager(pass_manager, toco_flags.debug_options()); + pass_manager.addInstrumentation( std::make_unique( pass_manager.getContext())); - InitPassManager(pass_manager, toco_flags.debug_options()); + if (failed(IsValidGraph(module))) { + return status_handler.ConsumeStatus(); + } + + Session* session = saved_model_bundle == nullptr + ? nullptr + : saved_model_bundle->GetSession(); if (pass_config.enable_stablehlo_conversion) { + // `ConvertTFExecutorToStablehloFlatbuffer` expects a `std::nullopt` if the + // `Session*` is a nullptr. + std::optional session_opt = + session == nullptr ? 
std::nullopt : std::make_optional(session); + // return to avoid adding TFL converter path return ConvertTFExecutorToStablehloFlatbuffer( - pass_manager, module, export_to_mlir, statusHandler, toco_flags, - pass_config, session, result, saved_model_tags); + pass_manager, module, export_to_mlir, status_handler, toco_flags, + pass_config, std::move(session_opt), result, saved_model_tags); } - tensorflow::AddPreVariableFreezingTFToTFLConversionPasses(pass_config, - &pass_manager); + if (pass_config.enable_hlo_to_tf_conversion) { + if (failed(RunHloToTfConversion( + pass_config, saved_model_dir, saved_model_tags, + toco_flags.quantization_config(), quantization_py_function_lib, + saved_model_bundle, pass_manager, status_handler, module))) { + return status_handler.ConsumeStatus(); + } + } + + AddPreVariableFreezingTFToTFLConversionPasses(pass_config, &pass_manager); if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } + // Freeze variables if a session is provided. - if (session.has_value()) { - mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext()); - if (failed( - mlir::tf_saved_model::FreezeVariables(module, session.value()))) { - auto status = statusHandler.ConsumeStatus(); - mlir::TFL::ErrorCollector* collector = - mlir::TFL::ErrorCollector::GetErrorCollector(); - if (!collector->CollectedErrors().empty()) { - // LINT.IfChange - return errors::InvalidArgument( - "Variable constant folding is failed. Please consider using " - "enabling `experimental_enable_resource_variables` flag in the " - "TFLite converter object. For example, " - "converter.experimental_enable_resource_variables = True"); - // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py) - } - return status; - } + if (session != nullptr && + failed(mlir::tf_saved_model::FreezeVariables(module, session))) { + return status_handler.Combine(absl::InvalidArgumentError( + "Variable constant folding is failed. 
Please consider using " + "enabling `experimental_enable_resource_variables` flag in the " + "TFLite converter object. For example, " + "converter.experimental_enable_resource_variables = True")); } pass_manager.clear(); - tensorflow::AddPostVariableFreezingTFToTFLConversionPasses( - saved_model_dir, toco_flags, pass_config, &pass_manager); + AddPostVariableFreezingTFToTFLConversionPasses(saved_model_dir, toco_flags, + pass_config, &pass_manager); if (failed(pass_manager.run(module))) { - auto status = statusHandler.ConsumeStatus(); - mlir::TFL::ErrorCollector* collector = - mlir::TFL::ErrorCollector::GetErrorCollector(); - for (const auto& error_data : collector->CollectedErrors()) { - if (error_data.subcomponent() == "FreezeGlobalTensorsPass") { - // LINT.IfChange - return errors::InvalidArgument( - "Variable constant folding is failed. Please consider using " - "enabling `experimental_enable_resource_variables` flag in the " - "TFLite converter object. For example, " - "converter.experimental_enable_resource_variables = True"); - // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py) - } - } - return status; + return status_handler.Combine(absl::InvalidArgumentError( + "Variable constant folding is failed. Please consider using " + "enabling `experimental_enable_resource_variables` flag in the " + "TFLite converter object. 
For example, " + "converter.experimental_enable_resource_variables = True")); } if (failed(GraphContainsStatefulPartitionedOp(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } if (export_to_mlir) { @@ -420,12 +469,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedTFLiteDialects())); if (failed(pass_manager.run(module))) { - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } llvm::raw_string_ostream os(*result); module.print(os); - return statusHandler.ConsumeStatus(); + return status_handler.ConsumeStatus(); } // Write MLIR TFLite dialect into FlatBuffer @@ -443,15 +492,11 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( } if (!tflite::MlirToFlatBufferTranslateFunction( module, options, &translated_result, serialize_stablehlo_ops)) { - auto s = statusHandler.ConsumeStatus(); - std::string message = "Could not translate MLIR to FlatBuffer."; - if (!s.ok()) { - absl::StrAppend(&message, " ", s.ToString()); - } - return absl::UnknownError(message); + return status_handler.Combine( + absl::InternalError("Could not translate MLIR to FlatBuffer.")); } - // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline + // TODO: b/176267167 - Quantize flex fallback in the MLIR pipeline if (quant_specs.weight_quantization && (!quant_specs.RunAndRewriteDynamicRangeQuantizationPasses() || !pass_config.emit_builtin_tflite_ops)) { @@ -460,30 +505,33 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( // statement. 
auto status = ApplyDynamicRangeQuantizationFromOldQuantizer( quant_specs, translated_result, result); - if (!status.ok()) return status; + if (!status.ok()) { + return status_handler.Combine(status); + } } else { *result = translated_result; } if (mlir::failed(module.verifyInvariants())) { - return tensorflow::errors::Unknown("Final module is invalid"); + return status_handler.Combine( + absl::InternalError("Final module is invalid.")); } - return OkStatus(); + return absl::OkStatus(); } -StatusOr> ImportSavedModel( +absl::StatusOr> ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, absl::Span extra_tf_opdefs, absl::Span exported_names, const GraphImportConfig& specs, bool enable_variable_lifting, mlir::MLIRContext* context, - std::unique_ptr* saved_model_bundle) { + std::unique_ptr* saved_model_bundle) { // Register extra TF ops passed as OpDef. auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); if (!extra_opdefs_status.ok()) return extra_opdefs_status; if (saved_model_version == 2) { - auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( + auto module_or = SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, context, /*unconditionally_use_set_output_shapes=*/true); if (!module_or.status().ok()) return module_or.status(); @@ -493,15 +541,14 @@ StatusOr> ImportSavedModel( options.upgrade_legacy = specs.upgrade_legacy; options.unconditionally_use_set_output_shapes = true; options.lift_variables = enable_variable_lifting; - auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( + auto module_or = SavedModelSignatureDefsToMlirImport( input_filename, tags, exported_names, context, options, saved_model_bundle); if (!module_or.status().ok()) return module_or.status(); return std::move(module_or).value(); } else { - return tensorflow::errors::InvalidArgument( - "Should be either saved model v1 or v2"); + return absl::InvalidArgumentError("Should be 
either saved model v1 or v2."); } } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 8ef8813c6c19fc..a82afbeee7eda6 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -22,15 +22,17 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/SourceMgr.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -84,9 +86,10 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, - llvm::StringRef saved_model_dir, - std::optional session, std::string* result, - bool serialize_stablehlo_ops = false); + llvm::StringRef saved_model_dir, SavedModelBundle* saved_model_bundle, + std::string* result, bool serialize_stablehlo_ops = false, + const quantization::PyFunctionLibrary* quantization_py_function_lib = + nullptr); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc 
b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index bf4224c7631dd0..f620995ea2ecfd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -14,25 +14,24 @@ limitations under the License. ==============================================================================*/ #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/IR/Location.h" // from @llvm-project +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" +#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/utils/utils.h" //===----------------------------------------------------------------------===// // The Pass to add default quantization parameters for the activations which @@ -215,7 +214,8 @@ quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( // The non-bias hasn't been quantized, let's skip this bias. 
if (non_bias_types.size() != non_biases.size()) return {}; - return func(non_bias_types, false); + return func(/*op_types=*/non_bias_types, /*adjusted_quant_dim=*/-1, + /*legacy_float_scale=*/false); } quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index a2ea10fe199736..cfe9bc754d8077 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -70,6 +70,9 @@ def ExtractSingleElementAsInt32 : NativeCodeCall< def CreateTFCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; +def CreateInt32ConstOrCast : NativeCodeCall< + "CreateInt32ConstOrCast($0, $_loc, $_builder)">; + def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; @@ -587,10 +590,7 @@ def LegalizeCumsum : Pat< def LegalizeReshape : Pat< (TF_ReshapeOp $input, $shape), - (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>; - -def ZeroIntAttr - : AttrConstraint().getInt() == 0">>; + (TFL_ReshapeOp $input, (CreateInt32ConstOrCast $shape))>; def LegalizeStridedSlice : Pat< (TF_StridedSliceOp diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 1ae84c64ddea7f..cc4e7e46b71b99 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -33,13 +33,17 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -106,6 +110,34 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { rewriter.getBoolAttr(false)); } +// Utility function to- +// 1. Create a tfl.const op with an int32_t values, from an MLIR Value, if the +// `Value` can be matched to a Constant DenseIntElementsAttr. +// This will make sure the dynamic dimensions are asigned to be `-1` +// 2. In the default case, cast the `Value` to an int32_t. +Value CreateInt32ConstOrCast(Value val, Location loc, + PatternRewriter& rewriter) { + if (val.getType().cast().hasStaticShape()) { + DenseElementsAttr shape_value_attr; + if (matchPattern(val, m_Constant(&shape_value_attr))) { + SmallVector new_shape_array_i32; + auto shape_value_array = shape_value_attr.getValues(); + for (int32_t idx = 0; idx < shape_value_array.size(); ++idx) { + auto size = shape_value_array[idx].getSExtValue(); + new_shape_array_i32.push_back( + ShapedType::isDynamic(size) ? 
-1 : static_cast(size)); + } + return rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get(new_shape_array_i32.size(), + rewriter.getIntegerType(32)), + new_shape_array_i32)); + } + } + + return CreateCastToInt32(val, loc, rewriter); +} + // Get shape of an operand or result, support both dynamic and static shape. Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { auto shaped_type = input.getType().cast(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index ec416dc3ea7939..13703233f6259f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" @@ -704,6 +705,252 @@ bool IsPermutationNCHW(Value perm) { #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" +// Returns 1D 32-bit dense elements attribute with the given values. 
+static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +// Get the number of leading 1s in the shape of the given input. +// Ex. input_shape = [1 x 1 x 1 x 1 x 2 x 1] => 4 +// returns 0 if the input shape is not static. +int GetNumLeadingOnes(ShapedType input_type) { + if (!input_type.hasStaticShape()) return 0; + auto input_shape = input_type.getShape(); + int num_leading_broadcast_dims = 0; + for (int i = 0; i < input_shape.size(); ++i) { + if (input_shape[i] == 1) { + ++num_leading_broadcast_dims; + } else { + break; + } + } + return num_leading_broadcast_dims; +} + +// Return the number of trailing 1s in the shape of the given input. +// Ex. input_shape = [1 x 1 x 2 x 1] => 1 +// returns 0 if the input shape is not static. 
+int GetNumTrailingOnes(ShapedType input_type) { + if (!input_type.hasStaticShape()) return 0; + auto input_shape = input_type.getShape(); + int num_trailing_broadcast_dims = 0; + for (int i = input_shape.size() - 1; i >= 0; --i) { + if (input_shape[i] == 1) { + ++num_trailing_broadcast_dims; + } else { + break; + } + } + return num_trailing_broadcast_dims; +} + +// Consider as Reshape( +// Broadcast( +// Reshape(input, // input_shape=[1 x n] +// inner_shape), // inner_shape=[1 x 1 x 1 x n x 1 x 1] +// broadcast_shape), // broadcast_shape=[1 x 1 x 1 x n x m x 1] +// outer_shape))) // outer_shape=[1 x 1 x n*m] +// Here the broadcast operation is used to create `m` repetetions of the `n` +// elements in the origiginal tensor, making a total of `m*n` number of elements +// in the final tensor that will then be reshaped to form something like +// [1 x 1 x 1 x m*n] by the outermost reshape_op. +// problem: The inefficiency here is that the innermost reshape_op and the +// broadcast_op are introducing unnecessary leading and trailing 1s'. +// fix: Remove the unnecessary 1s' in the inner reshape_op and broadcast_op. +struct SqueezeReshapesAroundBroadcastOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, + PatternRewriter &rewriter) const override { + auto loc = tfl_broadcast_to_op->getLoc(); + + // Match the + // Reshape( + // Broadcast( + // Reshape(input,inner_shape), + // broadcast_shape), + // outer_shape))) pattern. + if (!llvm::dyn_cast_or_null( + tfl_broadcast_to_op.getInput().getDefiningOp()) || + // Check that the broadcast_to op has only one use. 
+ !tfl_broadcast_to_op.getOutput().hasOneUse() || + !llvm::dyn_cast_or_null( + *tfl_broadcast_to_op.getOutput().getUsers().begin())) { + return rewriter.notifyMatchFailure( + loc, "No Reshape->BroadcastTo->Reshape pattern found"); + } + + // Pattern is applied only if the broadcast_to shape has more than 5 + // dimensions. + if (tfl_broadcast_to_op.getShape() + .getType() + .cast() + .getNumElements() < 6) { + return rewriter.notifyMatchFailure(loc, + "Not supported broadcast_to shape"); + } + auto inner_reshape_op = llvm::dyn_cast_or_null( + tfl_broadcast_to_op.getInput().getDefiningOp()); + auto inner_reshape_input = inner_reshape_op.getInput(); + auto outer_reshape_op = llvm::dyn_cast_or_null( + *tfl_broadcast_to_op.getOutput().getUsers().begin()); + + // Check that the outermost reshape_op in the pattern does not add + // additional elements to the final output tensor. + // TODO: b/323217483. This code needs to generalized to additional cases. + // For example- inner-shape = [1, 1, 1, 8, 1, 10], + // broadcast_shape = [1, 1, 1, 8, 16, 10] & outer_shape = [1, 1, 1, 1280, 1] + // And extend the pettern to handle dynamic shapes. + if (!inner_reshape_op.getOutput().getType().hasStaticShape() || + !tfl_broadcast_to_op.getOutput().getType().hasStaticShape() || + !outer_reshape_op.getOutput().getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "Unsupported shapes. Currely only static shapes are supported"); + } + + if (!IsLastDimEqualToNumElements(inner_reshape_input.getType(), + inner_reshape_op.getOutput().getType()) || + !IsLastDimEqualToNumElements( + outer_reshape_op.getOutput().getType(), + tfl_broadcast_to_op.getOutput().getType())) { + return rewriter.notifyMatchFailure( + loc, "Not supported Reshape->BroadcastTo->Reshape pattern"); + } + + // Calculate the number of extra leading and trailing 1s in the + // broadcast_op output. 
+ auto broadcast_output_shapetype = + tfl_broadcast_to_op.getOutput().getType().cast(); + int num_leading_broadcast_dims = + GetNumLeadingOnes(broadcast_output_shapetype); + int num_trailing_broadcast_dims = + GetNumTrailingOnes(broadcast_output_shapetype); + + // Get the new shape for the inner reshape_op after removing the extra 1s. + llvm::SmallVector new_reshape_shape_i32{ + inner_reshape_op.getOutput() + .getType() + .cast() + .getShape() + .drop_back(num_trailing_broadcast_dims) + .drop_front(num_leading_broadcast_dims)}; + + Value new_reshape_shape_value = rewriter.create( + inner_reshape_op->getLoc(), + GetI32ElementsAttr(new_reshape_shape_i32, &rewriter)); + + auto new_inner_reshape_op = rewriter.create( + inner_reshape_op->getLoc(), + inner_reshape_input, new_reshape_shape_value); + + // Create a new reshape_op to replace the old inner reshape_op. + rewriter.replaceOp(inner_reshape_op, new_inner_reshape_op.getResult()); + + // Get the new shape for the broadcast_op after removing the extra 1s. + llvm::SmallVector new_broadcast_shape{ + broadcast_output_shapetype.getShape() + .drop_back(num_trailing_broadcast_dims) + .drop_front(num_leading_broadcast_dims)}; + + Value new_broadcast_shape_value = rewriter.create( + loc, GetI64ElementsAttr(new_broadcast_shape, &rewriter)); + + auto new_broadcast_to_op = rewriter.create( + loc, RankedTensorType::get(new_broadcast_shape, rewriter.getF32Type()), + new_inner_reshape_op.getOutput(), new_broadcast_shape_value); + + // Create a new broadcast_op to replace the old broadcast_op. + rewriter.replaceOp(tfl_broadcast_to_op, new_broadcast_to_op.getResult()); + + return success(); + } +}; + +// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces +// it with a MulOp that multiplies the tensor by a splat constant with 1s. 
+struct ConvertTFLBroadcastToMulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, + PatternRewriter &rewriter) const override { + auto input_type = + tfl_broadcast_to_op.getInput().getType().cast(); + auto output_type = + tfl_broadcast_to_op.getOutput().getType().cast(); + auto shape_type = + tfl_broadcast_to_op.getShape().getType().cast(); + Type element_type = input_type.getElementType(); + + auto loc = tfl_broadcast_to_op->getLoc(); + + // Check that the output type is not dynamic and is less-than-equal to 4D or + // the shape type is static, 1D and has less-than-equal to 4 elements. + bool is_output_shape_dynamic = + (!output_type.hasRank() || (output_type.getRank() > 4) || + (output_type.getNumDynamicDims() > 0)); + bool is_broadcast_shape_dynamic = + (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || + (shape_type.getDimSize(0) > 4)); + if (is_output_shape_dynamic && is_broadcast_shape_dynamic) + return rewriter.notifyMatchFailure( + loc, "output_rank or broadcast_to shape not supported"); + + // Allow lowering when the input's elements type is F32, BFloat16, I32 or + // I16. + if (!(element_type.isa() || + element_type.isInteger(32) || element_type.isInteger(16))) + return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); + + // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise + // a Arith.ConstOp is created. 
+ if (is_output_shape_dynamic && + output_type.getElementType().isUnsignedInteger()) { + return rewriter.notifyMatchFailure( + loc, + "Unsigned broadcast_to output with dynamic shape is not supported"); + } + + Value mul_rhs_value; + if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { + auto status_or_const_op = + CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getShape(), + status_or_const_op.value()); + } else { + auto status_or_const_op = + CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); + if (!status_or_const_op.ok()) { + return failure(); + } + + mul_rhs_value = status_or_const_op.value(); + } + + auto mul_op = rewriter.create( + loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, + rewriter.getStringAttr("NONE")); + rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); + return success(); + } +}; + struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -886,14 +1133,9 @@ struct Convert2DUpscalingToResizeNearestNeighor SmallVector reshape_shape_in_int64( {1, image_size, image_size, feature_size}); - auto reshape_shape_type = - RankedTensorType::get({static_cast(reshape_shape.size())}, - rewriter.getIntegerType(32)); - auto reshape_shape_attr = - DenseIntElementsAttr::get(reshape_shape_type, reshape_shape); - auto reshape_shape_const_op = rewriter.create( - gather_nd_first->getLoc(), reshape_shape_attr); + gather_nd_first->getLoc(), + GetI32ElementsAttr(reshape_shape, &rewriter)); auto reshape_op = rewriter.create( gather_nd_first->getLoc(), @@ -903,12 +1145,8 @@ struct Convert2DUpscalingToResizeNearestNeighor // Add TFL::resize_nearest_neighor op for 2x upscaling. 
SmallVector size_vec = {image_size * 2, image_size * 2}; - auto size_type = mlir::RankedTensorType::get( - {static_cast(size_vec.size())}, rewriter.getIntegerType(32)); - auto size_attr = mlir::DenseIntElementsAttr::get(size_type, size_vec); - - auto size_const_op = - rewriter.create(gather_nd_first->getLoc(), size_attr); + auto size_const_op = rewriter.create( + gather_nd_first->getLoc(), GetI32ElementsAttr(size_vec, &rewriter)); auto resize = rewriter.create( gather_nd_first->getLoc(), transpose_second.getResult().getType(), @@ -1249,6 +1487,12 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { auto fc_op = dyn_cast_or_null( mul_op.getLhs().getDefiningOp()); if (!fc_op) return failure(); + + // Check if FullyConnected has only one use, that is the LHS of Mul Op. + // Otherwise this will duplicate the fullyconnected op to serve the + // remaining uses. + if (!fc_op->hasOneUse()) return failure(); + Value filter = fc_op.getFilter(); Value bias = fc_op.getBias(); ElementsAttr cst_tmp; @@ -1759,11 +2003,9 @@ struct ConvertTrivialTransposeOpToReshapeOp output_shape_values.push_back( ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); } - auto type = mlir::RankedTensorType::get(output_shape_values.size(), - rewriter.getIntegerType(32)); - auto new_shape_attr = - mlir::DenseIntElementsAttr::get(type, output_shape_values); - auto new_shape = rewriter.create(loc, new_shape_attr); + + auto new_shape = rewriter.create( + loc, GetI32ElementsAttr(output_shape_values, &rewriter)); rewriter.replaceOpWithNewOp( transpose_op, transpose_op.getOutput().getType(), @@ -1932,11 +2174,7 @@ struct FuseUnpackAndConcatToReshape ShapedType::isDynamic(size) ? 
-1 : static_cast(size)); } auto new_shape = rewriter.create( - concat_op.getLoc(), - DenseIntElementsAttr::get( - RankedTensorType::get(new_shape_array_i32.size(), - rewriter.getIntegerType(32)), - new_shape_array_i32)); + concat_op.getLoc(), GetI32ElementsAttr(new_shape_array_i32, &rewriter)); rewriter.replaceOpWithNewOp( concat_op, output_type, unpack_op.getInput(), new_shape); @@ -2126,9 +2364,7 @@ struct FuseReshapeAndTransposeAroundBatchMatmul transpose_input.getType().getShape().begin() + 2, transpose_input.getType().getShape().end(), 1, std::multiplies()))}; auto shape_constant = rewriter.create( - batch_matmul.getLoc(), - DenseIntElementsAttr::get( - RankedTensorType::get(3, rewriter.getI32Type()), new_shape)); + batch_matmul.getLoc(), GetI32ElementsAttr(new_shape, &rewriter)); auto reshaped_input = rewriter.create( batch_matmul.getLoc(), transpose_op.getInput(), shape_constant); rewriter.replaceOpWithNewOp( @@ -2190,10 +2426,7 @@ struct FuseTransposeReshapeIntoBatchMatmul reshape_op.getType().getShape().drop_front().end()); new_shape.push_back(reshape_op.getType().getDimSize(0)); auto shape_constant = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get( - RankedTensorType::get(reshape_op.getType().getRank(), - rewriter.getI32Type()), - new_shape)); + op.getLoc(), GetI32ElementsAttr(new_shape, &rewriter)); auto new_reshape = rewriter.create( op.getLoc(), transpose_op.getInput(), shape_constant); rewriter.replaceOpWithNewOp( @@ -2421,8 +2654,8 @@ void OptimizePass::runOnOperation() { // binary ops. 
RewritePatternSet phase_0_patterns(&getContext()); phase_0_patterns - .add( - ctx); + .add(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns)); // Potentially the binary ops might be fused together, like hard_swish, thus diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 114a01492ff16b..008decb62b0d55 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1129,6 +1129,26 @@ def OptimizeSliceOp : Pat< (replaceWithValue $input), [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>; +// Convert the StridedSliceOp to a SliceOp when possible. This will enable other +// optimizations on SliceOp to run. 
+def OptimizeStridedSlice : Pat< + (TFL_StridedSliceOp $input, + (Arith_ConstantOp $begin), + (Arith_ConstantOp $end), + (Arith_ConstantOp $stride), + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ZeroIntAttr:$_, + ConstBoolAttrFalse), + (TFL_SliceOp $input, + (Arith_ConstantOp $begin), + (Arith_ConstantOp (GetOffSet $begin, $end))), + [(IsAllOnesConstant $stride), + (HasNonNegativeValues $begin), + (HasNonNegativeOffset $begin, $end)]>; + def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0.getType())">; def ReshapeValueDroppingLastDim : NativeCodeCall< @@ -1510,14 +1530,7 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat< (BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output), (AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>; -// Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp -def FuseTransposeIntoBatchMatMulRHS: Pat< - (TFL_BatchMatMulOp $lhs, - (TFL_TransposeOp:$transposed_value $input, (Arith_ConstantOp:$perm_value $p0)), - $adj_x, $adj_y, $asymmetric_quantize_inputs), - (TFL_BatchMatMulOp $lhs, $input, $adj_x, ConstBoolAttrTrue, $asymmetric_quantize_inputs), - [(AreLastTwoDimsTransposed $perm_value), - (IsBoolAttrEqual<"false"> $adj_y)]>; + // Fuse redundant TFL_TransposeOp into TFL_BatchMatMulOp def FuseTransposeIntoBatchMatMulLHS: Pat< diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 4dfea319b336d7..45428d2648a43f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -309,6 +309,8 @@ def PrepareQuantizePass : Pass<"tfl-prepare-quantize", "mlir::func::FuncOp"> { "disable-set-input-nodes-quantization-params", "bool", "false", "Whether disable set input nodes quantization parameters.">, + Option<"is_qdq_conversion_", "is-qdq-conversion", "bool", "false", + "Whether the source graph is a QDQ model intended for conversion only.">, ]; } @@ -323,6 
+325,9 @@ def PrepareDynamicRangeQuantizePass : Pass<"tfl-prepare-quantize-dynamic-range", Option<"enable_dynamic_range_per_channel_quantization_", "enable-dynamic-range-per-channel-quantization", "bool", "true", "Whether enable per-channel quantized weights.">, + Option<"enable_dynamic_range_per_channel_quantization_for_dense_layers_", + "enable-dynamic-range-per-channel-quantization-for-dense-layers", "bool", + "true", "Whether enable per-channel quantized weights for Fully Connected layers (default is per tensor).">, Option<"min_elements_for_weights_", "min-elements-for-weights", "int64_t", "1024", "The minimum number of elements in a weights array required to apply quantization.">, diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 807d1c2dfaa2b5..fb613d74bbfaa2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ // This transformation pass applies quantization propagation on TFLite dialect. +#include #include +#include #include #include #include @@ -217,7 +219,9 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" bool PrepareQuantizePass::RemoveRedundantStats(func::FuncOp func) { - return RemoveRedundantStatsOps(func, GetOpQuantSpec); + return RemoveRedundantStatsOps( + func, std::bind(GetOpQuantSpec, std::placeholders::_1, + quant_specs_.disable_per_channel_for_dense_layers)); } static Value Quantized(Operation* user) { @@ -402,12 +406,21 @@ void PrepareQuantizePass::runOnOperation() { SanityCheckAndAdjustment(func); + // Bind the getter with the fixed configuration parameter for the correct + // quantization settings of the ops. 
+ std::function(Operation*)> + op_quant_spec_getter = + std::bind(GetOpQuantSpec, std::placeholders::_1, + quant_specs_.disable_per_channel_for_dense_layers); + // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). ApplyQuantizationParamsPropagation( func, is_signed, bit_width, - disable_per_channel_ || quant_specs_.disable_per_channel, GetOpQuantSpec, - infer_tensor_range, quant_specs_.legacy_float_scale); + disable_per_channel_ || quant_specs_.disable_per_channel, + op_quant_spec_getter, infer_tensor_range, quant_specs_.legacy_float_scale, + (is_qdq_conversion_ || + quant_specs_.qdq_conversion_mode != quant::QDQConversionMode::kQDQNone)); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index 951748b31273f3..a60ebe57212f9e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -72,6 +72,8 @@ class PrepareDynamicRangeQuantizePass : quant_specs_(quant_specs) { enable_dynamic_range_per_channel_quantization_ = !quant_specs_.disable_per_channel; + enable_dynamic_range_per_channel_quantization_for_dense_layers_ = + !quant_specs_.disable_per_channel_for_dense_layers; min_elements_for_weights_ = quant_specs_.minimum_elements_for_weights; } @@ -275,6 +277,10 @@ class PrepareDynamicRangeQuantizableOp op_with_per_axis_support = op_with_narrow_range && affine_user.GetQuantizationDimIndex() != -1 && !quant_specs_.disable_per_channel; + if (dyn_cast(quantize_op)) { + op_with_per_axis_support &= + !quant_specs_.disable_per_channel_for_dense_layers; + } } QuantizedType quant_type = nullptr; @@ -473,6 +479,8 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { quant_specs_.disable_per_channel = !enable_dynamic_range_per_channel_quantization_; + 
quant_specs_.disable_per_channel_for_dense_layers = + !enable_dynamic_range_per_channel_quantization_for_dense_layers_; quant_specs_.minimum_elements_for_weights = min_elements_for_weights_; if (!enable_custom_op_quantization_.empty()) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index 90a48d577ef669..216c6756ab67db 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -505,9 +505,10 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( double scale) { return [=](const std::vector& quant_params, - bool legacy_float_scale) -> quant::QuantParams { - if (auto qtype = quant::GetUniformQuantizedTypeForBias(quant_params, - legacy_float_scale) + const int adjusted_quant_dim, + const bool legacy_float_scale) -> quant::QuantParams { + if (auto qtype = quant::GetUniformQuantizedTypeForBias( + quant_params, legacy_float_scale, adjusted_quant_dim) .dyn_cast_or_null()) { return quant::UniformQuantizedType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index c80be89c567e09..2e920595819f84 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -39,7 +39,6 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -780,48 +779,6 @@ struct ConvertTFStridedSlice : public RewritePattern { } }; -struct ConvertTFBroadcastTo : public RewritePattern { - explicit ConvertTFBroadcastTo(MLIRContext *context) - : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto tf_broadcast_to_op = cast(op); - auto input_type = - tf_broadcast_to_op.getInput().getType().cast(); - auto output_type = - tf_broadcast_to_op.getOutput().getType().cast(); - auto shape_type = - tf_broadcast_to_op.getShape().getType().cast(); - Type element_type = input_type.getElementType(); - - // Allow lowering when low dimension inputs are given and its type is F32 or - // I32. 
- if (!((output_type.hasRank() && output_type.getRank() <= 4) || - (shape_type.hasStaticShape() && shape_type.getRank() == 1 && - shape_type.getDimSize(0) <= 4))) - return failure(); - - if (!(element_type.isa() || - element_type.isInteger(32) || element_type.isInteger(16))) - return failure(); - - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - auto tf_fill_op = rewriter.create(op->getLoc(), output_type, - tf_broadcast_to_op.getShape(), - status_or_const_op.value()); - - auto mul_op = rewriter.create( - op->getLoc(), output_type, tf_broadcast_to_op.getInput(), tf_fill_op); - rewriter.replaceOp(op, mul_op.getResult()); - return success(); - } -}; // The below pattern is equivalent to the DRR rule below // The checks are dependent on generated values, so we can't add @@ -1591,9 +1548,8 @@ void PrepareTFPass::runOnOperation() { if (unfold_batch_matmul_) { TF::PopulateUnrollTfBatchMatMul(ctx, phase_2_patterns); } - phase_2_patterns - .add(ctx); + phase_2_patterns.add(ctx); phase_2_patterns.add( ctx, allow_bf16_and_f16_type_legalization_); // Remove redundant reshape ops. 
diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 0bc21e41f68ced..0d3494d851aaba 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -257,7 +257,9 @@ void QuantizePass::runOnOperation() { populateWithGenerated(patterns); - if (quant_specs.weight_quantization || quant_specs.use_fake_quant_num_bits) { + if (quant_specs.weight_quantization || quant_specs.use_fake_quant_num_bits || + quant_specs.qdq_conversion_mode == + quant::QDQConversionMode::kQDQDynamic) { patterns.add(ctx, quant_params); } else { patterns.add(ctx, diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 86d0509ceb7e65..0e5d10e9e7469a 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -15,39 +15,47 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" +#include +#include #include #include +#include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/status.h" +#include "tsl/platform/statusor.h" namespace mlir { namespace TFL { -tsl::StatusOr CreateConstOpWithSingleValue( - PatternRewriter* rewriter, Location loc, ShapedType shaped_type, - int value) { +tsl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { Type element_type = shaped_type.getElementType(); - ShapedType scalar_type = RankedTensorType::get({}, element_type); - TypedAttr attr; if (element_type.isF16()) { auto floatType = mlir::FloatType::getF16(element_type.getContext()); auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); + return DenseElementsAttr::get(shaped_type, floatValues); } else if (element_type.isBF16()) { auto floatType = mlir::FloatType::getBF16(element_type.getContext()); auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); std::vector 
floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); + return DenseElementsAttr::get(shaped_type, floatValues); } else if (element_type.isF32()) { - attr = - DenseElementsAttr::get(scalar_type, static_cast(value)); + return DenseElementsAttr::get(shaped_type, + static_cast(value)); } else if (auto complex_type = element_type.dyn_cast()) { auto etype = complex_type.getElementType(); if (etype.isF32()) { @@ -64,7 +72,7 @@ tsl::StatusOr CreateConstOpWithSingleValue( repr.set_tensor_content(content); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - attr = mlir::TF::TensorProtoAttr::get(scalar_type, mangled); + return mlir::TF::TensorProtoAttr::get(shaped_type, mangled); } else { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); @@ -73,19 +81,19 @@ tsl::StatusOr CreateConstOpWithSingleValue( if (element_type.isSignedInteger()) { switch (itype.getWidth()) { case 8: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 16: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 32: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 64: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; default: @@ -95,19 +103,19 @@ tsl::StatusOr CreateConstOpWithSingleValue( } else { switch (itype.getWidth()) { case 8: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 16: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 32: - attr = DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; case 64: - attr = 
DenseElementsAttr::get(scalar_type, + return DenseElementsAttr::get(shaped_type, static_cast(value)); break; default: @@ -119,8 +127,29 @@ tsl::StatusOr CreateConstOpWithSingleValue( return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } +} + +// Returns a Constant op with a splat vector value. +tsl::StatusOr CreateConstOpWithVectorValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + ShapedType dense_type = RankedTensorType::get(shaped_type.getShape(), + shaped_type.getElementType()); + auto attr = CreateTypedAttr(dense_type, value); + + return rewriter->create(loc, dense_type, + cast(*attr)); +} + +tsl::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, + int value) { + ShapedType scalar_type = + RankedTensorType::get({}, shaped_type.getElementType()); + auto attr = CreateTypedAttr(scalar_type, value); + return rewriter->create(loc, scalar_type, - cast(attr)); + cast(*attr)); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h index 1a71bd55a85e8a..f062e31a557d17 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -31,6 +31,10 @@ namespace TFL { tsl::StatusOr CreateConstOpWithSingleValue( PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); +// Returns a Constant op with a splat vector value. 
+tsl::StatusOr CreateConstOpWithVectorValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + } // namespace TFL } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 6130bab6531ba2..9fce1bc44387c3 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ +#include #include #include #include #include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -58,6 +62,44 @@ inline bool OpHasSameStaticShapes(Operation* op) { return true; } +// Checks if all elements in the constant attribute value are 1. +inline bool IsAllOnesConstant(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of(values.begin(), values.end(), + [](int32_t element_value) { return element_value != 1; }); +} + +// Checks if all elements in the constant attribute value are non-negative. 
+inline bool HasNonNegativeValues(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of( + values.begin(), values.end(), + [](const APInt& element_value) { return element_value.isNegative(); }); +} + +// Utility function to get the offset between two dense attribute values. +inline TypedAttr GetOffSet(Attribute begin, Attribute end) { + auto begin_values = begin.cast().getValues(); + auto end_values = end.cast().getValues(); + + SmallVector offsets; + if (begin_values.size() == end_values.size()) { + for (size_t i = 0; i < begin_values.size(); ++i) { + offsets.push_back(end_values[i] - begin_values[i]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get({static_cast(offsets.size())}, + mlir::IntegerType::get(begin.getContext(), 32)), + llvm::ArrayRef(offsets)); +} + +// Check if the offset between two dense attribute values is non-negative. +inline bool HasNonNegativeOffset(Attribute begin, Attribute end) { + return HasNonNegativeValues(GetOffSet(begin, end)); +} + // Return true if the permutation value only swaps the last two dimensions inline bool AreLastTwoDimsTransposed(Value permutation) { if (!permutation) return false; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index e64b591ae78eda..42af8c67b2a7ce 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,23 @@ include "mlir/IR/PatternBase.td" // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; +// Constraint that values in list attribute are all ones. 
+def IsAllOnesConstant : Constraint>; + +// Constraint that checks if all values in offset between two +// attributes are non-negative. +def HasNonNegativeOffset : Constraint>; + +// Constraint that checks if all values in list attribute are non-negative. +def HasNonNegativeValues : Constraint>; + +// Utility function to get the offset between two dense attribute values. +def GetOffSet : NativeCodeCall<"TFL::GetOffSet($0, $1)">; + +// Attribute Constraint that checks if the attribute value is zero. +def ZeroIntAttr + : AttrConstraint().getInt() == 0">>; + // Checks if the value has rank at most 'n'. class HasRankAtLeast : Constraint< CPred<"$0.getType().cast().hasRank() && " diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index fd1ce2faa8e026..421b3df68642b2 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -52,6 +52,7 @@ tf_cc_test( name = "lift_as_function_call_test", srcs = ["lift_as_function_call_test.cc"], deps = [ + ":func", ":lift_as_function_call", ":test_base", "//tensorflow/compiler/mlir/tensorflow", @@ -65,6 +66,35 @@ tf_cc_test( ], ) +cc_library( + name = "func", + srcs = ["func.cc"], + hdrs = ["func.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/cc/saved_model:signature_constants", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "func_test", + srcs = ["func_test.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":func", + ":test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "test_base", testonly = 1, @@ -73,7 +103,7 @@ cc_library( compatible_with 
= get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:context", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:test", @@ -101,9 +131,9 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", ], ) @@ -113,11 +143,12 @@ tf_cc_test( srcs = ["attrs_and_constraints_test.cc"], deps = [ ":attrs_and_constraints", + ":func", ":test_base", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -139,3 +170,28 @@ td_library( "@llvm-project//mlir:FuncTdFiles", ], ) + +cc_library( + name = "uniform_quantized_types", + srcs = ["uniform_quantized_types.cc"], + hdrs = ["uniform_quantized_types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "uniform_quantized_types_test", + srcs = ["uniform_quantized_types_test.cc"], + deps = [ + ":uniform_quantized_types", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) diff --git 
a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index 9c3192d5345f28..1d2ccbdaaf4d2b 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -14,15 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include + #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" // IWYU pragma: keep namespace mlir::quant { @@ -59,4 +62,30 @@ SmallVector CloneOpWithReplacedOperands( return builder.clone(*op, mapping)->getResults(); } +FailureOr CastI64ToI32(const int64_t value) { + if (!llvm::isInt<32>(value)) { + DEBUG_WITH_TYPE( + "mlir-quant-attrs-and-constraints", + llvm::dbgs() + << "Tried to cast " << value + << "from int64 to int32, but lies out of range of int32.\n"); + return failure(); + } + return static_cast(value); +} + +FailureOr> CastI64ArrayToI32( + const ArrayRef int64_array) { + SmallVector int32_array{}; + int32_array.reserve(int64_array.size()); + + for (const int64_t i64 : int64_array) { + FailureOr cast_i32 = CastI64ToI32(i64); + if (failed(cast_i32)) return failure(); + + 
int32_array.push_back(*cast_i32); + } + return int32_array; +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index 8f298b56ec947e..e4c5e92294e221 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" namespace mlir::quant { @@ -150,6 +151,45 @@ FailureOr TryCast(Operation *op, const StringRef name) { } } +FailureOr CastI64ToI32(int64_t value); + +// Tries to cast an array of int64 to int32. If any of the element in the +// array is not in the range of int32, returns failure(). +FailureOr> CastI64ArrayToI32( + ArrayRef int64_array); + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation *FindUserOfType(Operation *op) { + for (Operation *user : op->getUsers()) { + if (isa(user)) { + return user; + } + } + return nullptr; +} + +// Returns the function attribute for the given call op which is lifted for +// quantization. 
+template +inline FlatSymbolRefAttr GetFuncAttr(LiftedOp call_op) { + static_assert(false, "DuplicateOp for call_op is not implemented."); +} + +template <> +inline FlatSymbolRefAttr GetFuncAttr( + TF::PartitionedCallOp call_op) { + return call_op.getFAttr().template dyn_cast(); +} + +template <> +inline FlatSymbolRefAttr GetFuncAttr( + TF::XlaCallModuleOp call_op) { + return call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); +} + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index afa7d44e1b595e..2c466b4415818b 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include + +#include #include #include "absl/strings/string_view.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -25,20 +29,26 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { namespace { using ::mlir::quant::QuantizationTestBase; using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::SubtractOp; +using ::testing::ElementsAreArray; +using ::testing::NotNull; class AttrsAndConstraintsTest : public QuantizationTestBase {}; constexpr absl::string_view kModuleStatic = R"mlir( module { - func.func private @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } @@ -47,16 +57,57 @@ constexpr absl::string_view kModuleStatic = R"mlir( constexpr absl::string_view kModuleDynamic = R"mlir( module { - func.func private @main(%arg0: tensor, %arg1: tensor<1024x3xf32>) -> tensor attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor, %arg1: tensor<1024x3xf32>) -> tensor attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor, tensor<1024x3xf32>) -> tensor return %0 : tensor } } )mlir"; +constexpr absl::string_view kModuleMultipleUses = R"mlir( + module { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general 
%arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.subtract %0, %arg2 : tensor<1x3xf32> + %2 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + } +)mlir"; + +constexpr absl::string_view kModuleXlaCallModule = R"mlir( + module { + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } + func.func private @composite_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + return %arg0 : tensor + } + } +)mlir"; + +constexpr absl::string_view kModulePartitionedCall = R"mlir( + module { + func.func @main(%arg0: tensor<2x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_fn_1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@main"("MatMul") at "QuantizationUnit(\12\06MatMul\1a\07main)")) + return %0 : tensor<2x2xf32> + } + func.func private @composite_fn_1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = 
"0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + } +)mlir"; + TEST_F(AttrsAndConstraintsTest, HasStaticShapeSucceedsWithStaticShapes) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Value dot_general_result = FindOperationOfType(main_fn)->getResult(0); EXPECT_TRUE(HasStaticShape(dot_general_result)); @@ -66,7 +117,9 @@ TEST_F(AttrsAndConstraintsTest, HasStaticShapeSucceedsWithStaticShapes) { TEST_F(AttrsAndConstraintsTest, HasStaticShapeFailsWithDynamicShapes) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleDynamic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Value dot_general_result = FindOperationOfType(main_fn)->getResult(0); EXPECT_FALSE(HasStaticShape(dot_general_result)); @@ -76,7 +129,9 @@ TEST_F(AttrsAndConstraintsTest, HasStaticShapeFailsWithDynamicShapes) { TEST_F(AttrsAndConstraintsTest, TryCastSucceeds) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); EXPECT_TRUE(succeeded( TryCast(dot_general_op, /*name=*/"dot_general_op"))); @@ -84,7 +139,9 @@ TEST_F(AttrsAndConstraintsTest, TryCastSucceeds) { TEST_F(AttrsAndConstraintsTest, TryCastFailsOnWrongType) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + 
Operation* dot_general_op = FindOperationOfType(main_fn); EXPECT_TRUE( failed(TryCast(dot_general_op, /*name=*/"dot_general_op"))); @@ -92,7 +149,9 @@ TEST_F(AttrsAndConstraintsTest, TryCastFailsOnWrongType) { TEST_F(AttrsAndConstraintsTest, TryCastFailsOnNullPtr) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStatic); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* op_nullptr = FindOperationOfType(main_fn)->getNextNode()->getNextNode(); // getNextNode() returns a nullptr if at the very last node. @@ -101,5 +160,74 @@ TEST_F(AttrsAndConstraintsTest, TryCastFailsOnNullPtr) { EXPECT_TRUE(failed(TryCast(nullptr, /*name=*/"nullptr"))); } +TEST_F(AttrsAndConstraintsTest, I64ValueInI32RangeAreCastedCorrectly) { + EXPECT_TRUE(succeeded(CastI64ToI32(llvm::minIntN(32)))); + EXPECT_TRUE(succeeded(CastI64ToI32(llvm::maxIntN(32)))); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ValueOutOfI32Range) { + EXPECT_TRUE(failed(CastI64ToI32(llvm::minIntN(32) - 10))); + EXPECT_TRUE(failed(CastI64ToI32(llvm::maxIntN(32) + 10))); +} + +TEST_F(AttrsAndConstraintsTest, I64ArrayInI32RangeAreCastedCorrectly) { + const SmallVector array_i64 = {llvm::minIntN(32), -2, -1, 0, 1, 2, + llvm::maxIntN(32)}; + + FailureOr> array_i32 = CastI64ArrayToI32(array_i64); + EXPECT_TRUE(succeeded(array_i32)); + EXPECT_THAT( + *array_i32, + ElementsAreArray({static_cast(llvm::minIntN(32)), -2, -1, 0, 1, + 2, static_cast(llvm::maxIntN(32))})); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayUnderI32Range) { + const int64_t under_min_i32 = -2147483658; + ArrayRef array_i64{under_min_i32}; + EXPECT_EQ(under_min_i32, llvm::minIntN(32) - 10); + EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); +} + +TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayAboveI32Range) { + const int64_t below_max_i32 = 2147483657; + ArrayRef array_i64{below_max_i32}; + 
EXPECT_EQ(below_max_i32, llvm::maxIntN(32) + 10); + EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); +} + +TEST_F(AttrsAndConstraintsTest, FindUserOfDifferentTypes) { + OwningOpRef module_op_ref = + ParseModuleOpString(kModuleMultipleUses); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + + Operation* dot_general_op = FindOperationOfType(main_fn); + ASSERT_NE(FindUserOfType(dot_general_op), nullptr); + ASSERT_NE(FindUserOfType(dot_general_op), nullptr); + ASSERT_NE(FindUserOfType<>(dot_general_op), nullptr); + ASSERT_EQ(FindUserOfType(dot_general_op), nullptr); +} + +TEST_F(AttrsAndConstraintsTest, CallGetFuncAttr) { + OwningOpRef xla_module_op_ref = + ParseModuleOpString(kModuleXlaCallModule); + func::FuncOp xml_main_fn = FindMainFuncOp(*xla_module_op_ref); + Operation* xla_op = FindOperationOfType(xml_main_fn); + auto xla_call_op = dyn_cast_or_null(*xla_op); + FlatSymbolRefAttr xla_call_op_attr = GetFuncAttr(xla_call_op); + EXPECT_EQ(xla_call_op_attr.getValue(), "composite_fn_1"); + + OwningOpRef partitioned_module_op_ref = + ParseModuleOpString(kModulePartitionedCall); + func::FuncOp partitioned_main_fn = FindMainFuncOp(*partitioned_module_op_ref); + Operation* partitioned_op = + FindOperationOfType(partitioned_main_fn); + auto partitioned_call_op = + dyn_cast_or_null(*partitioned_op); + FlatSymbolRefAttr partitioned_call_op_attr = GetFuncAttr(partitioned_call_op); + EXPECT_EQ(partitioned_call_op_attr.getValue(), "composite_fn_1"); +} + } // namespace } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/func.cc b/tensorflow/compiler/mlir/quantization/common/func.cc new file mode 100644 index 00000000000000..5849289e6d7ebd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/func.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir::quant { +namespace { + +using ::tensorflow::kDefaultServingSignatureDefKey; +using ::tensorflow::kImportModelDefaultGraphFuncName; + +// Returns true iff the function's symbol is public. 
+bool IsPublicFuncOp(func::FuncOp func_op) { + return SymbolTable::getSymbolVisibility(&*func_op) == + SymbolTable::Visibility::Public; +} + +} // namespace + +func::FuncOp FindMainFuncOp(ModuleOp module_op) { + if (const auto main_func_op = module_op.lookupSymbol( + kImportModelDefaultGraphFuncName); + main_func_op != nullptr && IsPublicFuncOp(main_func_op)) { + return main_func_op; + } + + if (const auto serving_default_func_op = + module_op.lookupSymbol(kDefaultServingSignatureDefKey); + serving_default_func_op != nullptr && + IsPublicFuncOp(serving_default_func_op)) { + return serving_default_func_op; + } + + return nullptr; +} + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/func.h b/tensorflow/compiler/mlir/quantization/common/func.h new file mode 100644 index 00000000000000..ade7bcfc71027b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir::quant { + +// Returns a public `func::FuncOp` in `module_op` whose name matches either +// `main` or `serving_default`. If `func::FuncOps` with both names exist, the +// function with name "main" takes precedence. Returns null if no such +// function exists. +func::FuncOp FindMainFuncOp(ModuleOp module_op); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/func_test.cc b/tensorflow/compiler/mlir/quantization/common/func_test.cc new file mode 100644 index 00000000000000..8555da63b71feb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/func_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/func.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" + +namespace mlir::quant { +namespace { + +using ::testing::IsNull; +using ::testing::NotNull; + +class FindMainFuncOpTest : public QuantizationTestBase {}; + +TEST_F(FindMainFuncOpTest, ReturnsMainFuncOp) { + constexpr absl::string_view kModuleWithMainFunc = R"mlir( + module { + func.func @main() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = ParseModuleOpString(kModuleWithMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + func::FuncOp main_func_op = FindMainFuncOp(*module_op); + EXPECT_THAT(main_func_op, NotNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenMainFuncOpIsPrivate) { + constexpr absl::string_view kModuleWithPrivateMainFunc = R"mlir( + module { + func.func private @main() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithPrivateMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsServingDefaultFuncOp) { + constexpr absl::string_view kModuleWithServingDefaultFunc = R"mlir( + module { + func.func @serving_default() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithServingDefaultFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), NotNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenServingDefaultFuncOpIsPrivate) { + constexpr absl::string_view kModuleWithPrivateServingDefaultFunc = R"mlir( + module { + func.func private @serving_default() -> () { + return + } + } + 
)mlir"; + + OwningOpRef module_op = + ParseModuleOpString(kModuleWithPrivateServingDefaultFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +TEST_F(FindMainFuncOpTest, ReturnsNullWhenMainFuncNotFound) { + constexpr absl::string_view kModuleWithNoMainFunc = R"mlir( + module { + func.func @foo() -> () { + return + } + } + )mlir"; + + OwningOpRef module_op = ParseModuleOpString(kModuleWithNoMainFunc); + EXPECT_THAT(*module_op, NotNull()); + + EXPECT_THAT(FindMainFuncOp(*module_op), IsNull()); +} + +} // namespace +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index 36ed212afb6259..2a1b10cc3163a2 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include #include #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" @@ -28,13 +29,14 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir::quant { namespace { -using ::mlir::quant::QuantizationTestBase; +using ::testing::NotNull; class LiftAsFunctionCallTest : public QuantizationTestBase {}; @@ -49,8 +51,10 @@ constexpr absl::string_view kModuleLifted = R"mlir( TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleLifted); - func::FuncOp composite_dot_general_fn = - GetFunctionFromModule(*module_op_ref, "composite_dot_general_fn_1"); + auto composite_dot_general_fn = + module_op_ref->lookupSymbol("composite_dot_general_fn_1"); + ASSERT_THAT(composite_dot_general_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType( composite_dot_general_fn); @@ -59,7 +63,7 @@ TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { constexpr absl::string_view kModuleStableHlo = R"mlir( module { - func.func private @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + func.func @main(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } @@ -68,15 +72,18 @@ constexpr absl::string_view kModuleStableHlo = R"mlir( TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, 
NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); const SmallVector& attributes = { - builder_.getNamedAttr("precision_config", - builder_.getArrayAttr(SmallVector( - 1, stablehlo::PrecisionAttr::get( - &ctx_, stablehlo::Precision::DEFAULT)))), + builder_.getNamedAttr( + "precision_config", + builder_.getArrayAttr(SmallVector( + 1, mlir::stablehlo::PrecisionAttr::get( + ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))), }; Operation* lifted_op = LiftAsFunctionCall(builder_, dot_general_op->getLoc(), @@ -99,13 +106,15 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { EXPECT_EQ( lifted_dot_general_op->getAttr("precision_config").cast(), builder_.getArrayAttr(SmallVector( - 1, stablehlo::PrecisionAttr::get(&ctx_, - stablehlo::Precision::DEFAULT)))); + 1, mlir::stablehlo::PrecisionAttr::get( + ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))); } TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { OwningOpRef module_op_ref = ParseModuleOpString(kModuleStableHlo); - func::FuncOp main_fn = GetFunctionFromModule(*module_op_ref, "main"); + func::FuncOp main_fn = FindMainFuncOp(*module_op_ref); + ASSERT_THAT(main_fn, NotNull()); + Operation* dot_general_op = FindOperationOfType(main_fn); Operation* lifted_op = diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index 5402769bf04c2a..1068a42a615027 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ +#include + #include #include "absl/strings/string_view.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -25,10 +27,10 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -40,31 +42,26 @@ using ::testing::Test; class QuantizationTestBase : public Test { protected: - QuantizationTestBase() { - ctx_.loadDialect(); + QuantizationTestBase() + : ctx_(stablehlo::CreateMlirContextForQuantization()), + builder_(ctx_.get()) { + ctx_->loadDialect(); } // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created // module op is valid. OwningOpRef ParseModuleOpString( const absl::string_view module_op_str) { - auto module_op_ref = parseSourceString(module_op_str, &ctx_); + auto module_op_ref = parseSourceString(module_op_str, ctx_.get()); EXPECT_TRUE(module_op_ref); return module_op_ref; } - // Gets the function with the given name from the module. - func::FuncOp GetFunctionFromModule(ModuleOp module, - absl::string_view function_name) { - SymbolTable symbol_table(module); - return symbol_table.lookup(function_name); - } - // Returns the first operation with the given type in the function. 
template OpType FindOperationOfType(func::FuncOp function) { @@ -74,8 +71,8 @@ class QuantizationTestBase : public Test { return nullptr; } - mlir::MLIRContext ctx_{}; - OpBuilder builder_{&ctx_}; + std::unique_ptr ctx_; + OpBuilder builder_; }; } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc similarity index 72% rename from tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc rename to tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc index eecc96b04be9eb..5cee0692080b2a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc @@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -33,12 +35,14 @@ namespace quant { UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, MLIRContext& context, const double scale, - const int64_t zero_point) { + const int64_t zero_point, + const bool narrow_range) { return UniformQuantizedType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), /*expressedType=*/FloatType::getF32(&context), scale, zero_point, - /*storageTypeMin=*/llvm::minIntN(8), /*storageTypeMax=*/llvm::maxIntN(8)); + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 
1 : 0), + /*storageTypeMax=*/llvm::maxIntN(8)); } UniformQuantizedType CreateI32F32UniformQuantizedType( @@ -54,16 +58,30 @@ UniformQuantizedType CreateI32F32UniformQuantizedType( UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( const Location loc, MLIRContext& context, const ArrayRef scales, - const ArrayRef zero_points, const int quantization_dimension) { + const ArrayRef zero_points, const int quantization_dimension, + const bool narrow_range) { return UniformQuantizedPerAxisType::getChecked( loc, /*flags=*/QuantizationFlags::Signed, /*storageType=*/IntegerType::get(&context, /*width=*/8), /*expressedType=*/FloatType::getF32(&context), SmallVector(scales), SmallVector(zero_points), - quantization_dimension, /*storageTypeMin=*/llvm::minIntN(8), + quantization_dimension, + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 1 : 0), /*storageTypeMax=*/llvm::maxIntN(8)); } +UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension) { + return UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/32), + /*expressedType=*/FloatType::getF32(&context), + SmallVector(scales), SmallVector(zero_points), + quantization_dimension, /*storageTypeMin=*/llvm::minIntN(32), + /*storageTypeMax=*/llvm::maxIntN(32)); +} + bool IsStorageTypeI8(const QuantizedType quantized_type) { const Type storage_type = quantized_type.getStorageType(); return storage_type.isInteger(/*width=*/8); @@ -151,6 +169,30 @@ bool IsI32F32UniformQuantizedType(const Type type) { return true; } +bool IsI32F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + type.dyn_cast_or_null(); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. 
Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + // Determines whether the storage type of a quantized type is supported by // `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { @@ -165,5 +207,18 @@ bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { return false; } +bool IsQuantizedTensorType(Type type) { + if (!type.isa()) { + return false; + } + Type element_type = type.cast().getElementType(); + return element_type.isa(); +} + +bool IsOpFullyQuantized(Operation* op) { + return llvm::all_of(op->getOperandTypes(), IsQuantizedTensorType) && + llvm::all_of(op->getResultTypes(), IsQuantizedTensorType); +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h similarity index 69% rename from tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h rename to tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h index d04dc5a5761b8f..f1c94302d816b3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ #include @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -31,17 +32,20 @@ namespace quant { // values. The produced type has f32 as its expressed type and i8 as its // storage type. The available values use the full range of the storage value, // i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point -// values can be non-zero values. +// value can be a non-zero value. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, MLIRContext& context, double scale, - int64_t zero_point); + int64_t zero_point, + bool narrow_range = false); // Creates a `UniformQuantizedType` with the given `scale` and `zero_point` // values. The produced type has f32 as its expressed type and i32 as its // storage type. The available values use the full range of the storage value. -// Assumes asymmetric quantization, meaning the zero point values can be -// non-zero values. +// Assumes asymmetric quantization, meaning the zero point value can be +// a non-zero value. 
UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, MLIRContext& context, double scale, @@ -52,7 +56,19 @@ UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, // i8 as its storage type. The available values use the full range of the // storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the // zero point values can be non-zero values. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension, + bool narrow_range = false); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i32 as its storage type. The available values use the full range of the +// storage value. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( Location loc, MLIRContext& context, ArrayRef scales, ArrayRef zero_points, int quantization_dimension); @@ -74,11 +90,21 @@ bool IsI8F32UniformQuantizedPerAxisType(Type type); // 32-bit integer and expressed type is f32. bool IsI32F32UniformQuantizedType(Type type); +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedPerAxisType(Type type); + // Determines whether the storage type of a quantized type is supported by // `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); +// Returns true if a type is quantized tensor type. +bool IsQuantizedTensorType(Type type); + +// Returns true if all operands and results are quantized. 
+bool IsOpFullyQuantized(Operation* op); + } // namespace quant } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc similarity index 56% rename from tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc rename to tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index f33b322cfbd9e4..10499526873f2c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include #include @@ -31,6 +31,7 @@ namespace quant { namespace { using ::testing::ElementsAreArray; +using ::testing::IsNull; using ::testing::NotNull; using ::testing::Test; @@ -47,7 +48,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, I8StorageTypeSucceeds) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); - + // Storage type of `i8` is currently verifiable as `unsigned` in `Types.cpp`. 
EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } @@ -76,6 +77,15 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxEqualToI8MinMax) { EXPECT_EQ(quantized_type.getStorageTypeMax(), 127); } +TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxNarrowRange) { + const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType( + UnknownLoc::get(&ctx_), ctx_, + /*scale=*/1.0, /*zero_point=*/0, /*narrow_range=*/true); + + EXPECT_EQ(quantized_type.getStorageTypeMin(), -127); + EXPECT_EQ(quantized_type.getStorageTypeMax(), 127); +} + TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, @@ -99,6 +109,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, I32StorageTypeSucceeds) { CreateI32F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); + // Storage type of `i32` is currently verifiable as `unsigned` in `Types.cpp`. EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); } @@ -156,6 +167,7 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, I8StorageTypeSucceeds) { /*zero_points=*/SmallVector{0, 0}, /*quantization_dimension=*/0); + // Storage type of `i8` is currently verifiable as `unsigned` in `Types.cpp`. 
EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); } @@ -195,6 +207,19 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, EXPECT_EQ(quantized_type.getStorageTypeMax(), 127); } +TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, + StorageTypeMinMaxNarrowRange) { + const UniformQuantizedPerAxisType quantized_type = + CreateI8F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0, /*narrow_range=*/true); + + EXPECT_EQ(quantized_type.getStorageTypeMin(), -127); + EXPECT_EQ(quantized_type.getStorageTypeMax(), 127); +} + TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, HasQuantizationDimensionProperlySet) { const UniformQuantizedPerAxisType quantized_type = @@ -220,62 +245,139 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); } +class CreateI32F32UniformQuantizedPerAxisTypeTest : public Test { + protected: + CreateI32F32UniformQuantizedPerAxisTypeTest() : ctx_() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; +}; + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, I32StorageTypeSucceeds) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + // Storage type of `i32` is currently verifiable as `unsigned` in `Types.cpp`. 
+ EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(32)); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, F32ExpressedTypeSucceeds) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + EXPECT_TRUE(quantized_type.getExpressedType().isF32()); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + StorageTypeMinMaxEqualToI32MinMax) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/0); + + EXPECT_EQ(quantized_type.getStorageTypeMin(), + std::numeric_limits::min()); + EXPECT_EQ(quantized_type.getStorageTypeMax(), + std::numeric_limits::max()); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + HasQuantizationDimensionProperlySet) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{1.0, 1.0}, + /*zero_points=*/SmallVector{0, 0}, + /*quantization_dimension=*/3); + + EXPECT_EQ(quantized_type.getQuantizedDimension(), 3); +} + +TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, + HasScaleAndZeroPointProperlySet) { + const UniformQuantizedPerAxisType quantized_type = + CreateI32F32UniformQuantizedPerAxisType( + UnknownLoc::get(&ctx_), ctx_, + /*scales=*/SmallVector{8.0, 9.0}, + /*zero_points=*/SmallVector{98, 99}, + /*quantization_dimension=*/0); + + EXPECT_THAT(quantized_type.getScales(), ElementsAreArray({8.0, 9.0})); + EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); +} + class IsI8F32UniformQuantizedTypeTest : public Test { protected: - IsI8F32UniformQuantizedTypeTest() { + IsI8F32UniformQuantizedTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } 
MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI8F32UniformQuantizedTypeTest, I8F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsI8F32UniformQuantizedType(qi8_type)); } TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); } TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsStorageTypeI8(qi8_type)); } TEST_F(IsI8F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, 
/*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsExpressedTypeF32(qi8_type)); } class IsI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: - IsI8F32UniformQuantizedPerAxisTypeTest() { + IsI8F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, I8F32UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsI8F32UniformQuantizedPerAxisType(qi8_per_axis_type)); EXPECT_FALSE(IsI8F32UniformQuantizedType(qi8_per_axis_type)); } @@ -283,10 +385,11 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), NotNull()); } @@ -294,96 +397,187 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, 
builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsStorageTypeI8(qi8_per_axis_type)); } TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedPerAxisType qi8_per_axis_type = quant::UniformQuantizedPerAxisType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), /*scales=*/{1.0}, - /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); EXPECT_TRUE(IsExpressedTypeF32(qi8_per_axis_type)); } class IsI32F32UniformQuantizedTypeTest : public Test { protected: - IsI32F32UniformQuantizedTypeTest() { + IsI32F32UniformQuantizedTypeTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsI32F32UniformQuantizedTypeTest, I32F32UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); } TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, - 
/*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); } TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); EXPECT_TRUE(IsStorageTypeI32(qi32_type)); } TEST_F(IsI32F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { const UniformQuantizedType qi32_per_axis_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, + /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); +} + +class IsI32F32UniformQuantizedPerAxisTypeTest : public Test { + protected: + IsI32F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_; +}; + +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, + I32F32UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + 
/*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); + EXPECT_TRUE(IsI32F32UniformQuantizedPerAxisType(qi32_per_axis_type)); + EXPECT_FALSE(IsI32F32UniformQuantizedType(qi32_per_axis_type)); +} + +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, + I8F32UniformQuantizedTypeFails) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/-128, + /*storageTypeMax=*/127); + EXPECT_FALSE(IsI32F32UniformQuantizedPerAxisType(qi8_type)); + EXPECT_FALSE(IsStorageTypeI32(qi8_type)); + EXPECT_THAT(qi8_type.dyn_cast_or_null(), + IsNull()); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); + + EXPECT_THAT( + qi32_per_axis_type.dyn_cast_or_null(), + NotNull()); +} + +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); + + EXPECT_TRUE(IsStorageTypeI32(qi32_per_axis_type)); +} + +TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { + const UniformQuantizedPerAxisType qi32_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/QuantizationFlags::Signed, builder_.getI32Type(), + builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, 
/*quantizedDimension=*/0, + /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); } class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public Test { protected: - IsSupportedByTfliteQuantizeOrDequantizeOpsTest() { + IsSupportedByTfliteQuantizeOrDequantizeOpsTest() : builder_(&ctx_) { ctx_.loadDialect(); } MLIRContext ctx_; - OpBuilder builder_{&ctx_}; + OpBuilder builder_; }; TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/true), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi8_type.getStorageType()))); } TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeI16Succeeds) { auto qi16_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(16, /*isSigned=*/true), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi16_type.getStorageType()))); } TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, StorageTypeUI8Succeeds) { auto qi8_type = quant::UniformQuantizedType::get( - /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/false), - builder_.getF32Type(), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), + builder_.getF32Type(), + 
/*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( dyn_cast_or_null(qi8_type.getStorageType()))); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 7d2d1f87ef1831..439e8542cf7d0e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -52,7 +52,6 @@ cc_library( "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", "passes/optimize_graph.cc", - "passes/populate_shape.cc", "passes/post_quantize.cc", "passes/prepare_quantize.cc", "passes/quantize.cc", @@ -74,16 +73,17 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":optimize_graph_inc_gen", + ":quantization_config_proto_cc", ":quantization_options_proto_cc", ":quantization_patterns", ":stablehlo_passes_inc_gen", ":stablehlo_type_utils", - ":uniform_quantized_types", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", @@ -109,6 +109,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", @@ 
-124,6 +125,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/mlir_hlo", @@ -144,15 +146,14 @@ cc_library( ], compatible_with = get_compatible_with_portable(), deps = [ - ":uniform_quantized_types", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", "@com_google_absl//absl/algorithm:container", @@ -474,28 +475,38 @@ gentbl_cc_library( cc_library( name = "test_passes", srcs = [ + "passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc", "passes/testing/test_post_calibration_component.cc", "passes/testing/test_pre_calibration_component.cc", + "passes/testing/test_tf_to_stablehlo_pass.cc", ], hdrs = [ "passes/testing/passes.h", ], compatible_with = get_compatible_with_portable(), deps = [ + ":passes", ":quantization_config_proto_cc", ":stablehlo_test_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:post_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pre_calibration", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", + 
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", "@local_xla//xla/mlir_hlo", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", @@ -741,28 +752,3 @@ tf_cc_binary( "@stablehlo//:vhlo_ops", ], ) - -cc_library( - name = "uniform_quantized_types", - srcs = ["uniform_quantized_types.cc"], - hdrs = ["uniform_quantized_types.h"], - compatible_with = get_compatible_with_portable(), - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], -) - -tf_cc_test( - name = "uniform_quantized_types_test", - srcs = ["uniform_quantized_types_test.cc"], - deps = [ - ":uniform_quantized_types", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], -) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index d02f2651cfd3e2..10d4b020166552 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -7,8 +7,9 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", - "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + # For TFLite Converter integration. 
+ "//tensorflow/compiler/mlir/lite:__subpackages__", + "//tensorflow/compiler/mlir/quantization:__subpackages__", ], licenses = ["notice"], ) @@ -32,6 +33,29 @@ cc_library( ], ) +# OSS: This is a header-only target. Do NOT directly depend on `config_impl` unless it is necessary +# (e.g. undefined symbol error), to avoid ODR violation. +cc_library( + name = "config", + hdrs = ["config.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + +# OSS: This is a impl target corresponding to `config`. Do NOT directly depend on `config_impl` +# unless it is necessary (e.g. undefined symbol error), to avoid ODR violation. +cc_library( + name = "config_impl", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + cc_library( name = "io", srcs = ["io.cc"], @@ -187,6 +211,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", @@ -210,7 +237,6 @@ cc_library( deps = [ ":component", ":pass_pipeline", - "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -219,7 +245,6 @@ cc_library( "@com_google_absl//absl/log:die_if_null", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", @@ -264,7 +289,6 @@ cc_library( deps = [ ":component", ":pass_pipeline", - "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", @@ -272,7 +296,6 @@ cc_library( "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", @@ -311,6 +334,11 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow:__pkg__", + "//tensorflow/compiler/mlir/lite:__pkg__", # For tf_tfl_translate binary. + # For odml_to_stablehlo binary. + "//tensorflow/compiler/mlir/lite/stablehlo:__pkg__", + # For StableHLO Quantizer adapter functionalities within TFLite. Testonly. 
+ "//tensorflow/compiler/mlir/lite/quantization/stablehlo:__pkg__", "//tensorflow/python:__pkg__", ], deps = [ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD index 427f61996b2881..b31b9d5321a1ab 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -67,6 +67,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":assign_ids", + ":representative_dataset", ":statistics", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", @@ -104,3 +105,33 @@ cc_library( "@local_tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "representative_dataset", + srcs = ["representative_dataset.cc"], + hdrs = ["representative_dataset.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "representative_dataset_test", + srcs = ["representative_dataset_test.cc"], + deps = [ + ":representative_dataset", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index 
98131a96bb3aba..4f6478efebf2f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/assign_ids.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" @@ -60,12 +61,13 @@ limitations under the License. #include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { - namespace { using ::stablehlo::quantization::AddCalibrationStatistics; using ::stablehlo::quantization::AssignIdsToCustomAggregatorOps; +using ::stablehlo::quantization::CreateRepresentativeDatasetFileMap; using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::RepresentativeDatasetConfig; using ::stablehlo::quantization::io::CreateTmpDir; using ::stablehlo::quantization::io::GetLocalTmpFileName; using ::tensorflow::AssetFileDef; @@ -77,7 +79,6 @@ using ::tensorflow::quantization::CalibrationOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PreprocessAndFreezeGraph; using ::tensorflow::quantization::PyFunctionLibrary; -using ::tensorflow::quantization::RepresentativeDatasetFile; using ::tensorflow::quantization::RunPasses; using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; @@ -132,7 +133,8 @@ absl::StatusOr> RunExportPasses( } if (absl::Status pass_run_status = RunPasses( - /*name=*/export_opts.debug_name, + /*name=*/ + export_opts.debug_name, /*add_passes_func=*/ [dup_constants = 
export_opts.duplicate_shape_determining_constants]( PassManager& pm) { AddExportPasses(pm, dup_constants); }, @@ -160,8 +162,6 @@ CalibrationComponent::CalibrationComponent( std::unordered_set tags, absl::flat_hash_map signature_def_map, std::vector signature_keys, - absl::flat_hash_map - representative_dataset_file_map, const CalibrationOptions& calibration_options) : ctx_(ABSL_DIE_IF_NULL(ctx)), // Crash OK py_function_lib_(ABSL_DIE_IF_NULL(py_function_lib)), // Crash OK @@ -170,7 +170,6 @@ CalibrationComponent::CalibrationComponent( tags_(std::move(tags)), signature_def_map_(std::move(signature_def_map)), signature_keys_(std::move(signature_keys)), - representative_dataset_file_map_(representative_dataset_file_map), calibration_options_(calibration_options) {} absl::StatusOr CalibrationComponent::ExportToSavedModel( @@ -247,13 +246,23 @@ absl::StatusOr CalibrationComponent::Run( ExportedModel exported_model, ExportToSavedModel(module_op, precalibrated_saved_model_dir)); + // Translates `RepresentativeDatasetConfig`s to signature key -> + // `RepresentativeDatasetFile` mapping. + const auto dataset_configs = + config.static_range_ptq_preset().representative_datasets(); + const std::vector dataset_config_vector( + dataset_configs.begin(), dataset_configs.end()); + TF_ASSIGN_OR_RETURN( + const auto representative_dataset_file_map, + CreateRepresentativeDatasetFileMap(dataset_config_vector)); + // Runs calibration on the exported model. The statistics will be stored in a // separate singleton object `CalibratorSingleton` and are directly added to // `exported_model` without re-importing it. 
py_function_lib_->RunCalibration(precalibrated_saved_model_dir, signature_keys_, tags_, calibration_options_, /*force_graph_mode_calibration=*/true, - representative_dataset_file_map_); + representative_dataset_file_map); if (absl::Status status = AddCalibrationStatistics(*exported_model.mutable_graph_def(), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h index f049ab45c7561f..1d2cb94b93a0bc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h @@ -66,9 +66,6 @@ class CalibrationComponent : public Component { absl::flat_hash_map signature_def_map, std::vector signature_keys, - absl::flat_hash_map - representative_dataset_file_map, const tensorflow::quantization::CalibrationOptions& calibration_options); // Runs calibration on `module_op` and returns a calibrated ModuleOp with @@ -113,13 +110,6 @@ class CalibrationComponent : public Component { // Signature keys to identify the functions to load & quantize. const std::vector signature_keys_; - // Map from signature key to the representative dataset file. The keys should - // match `signature_keys_`. The representative datasets will be fed to the - // pre-calibrated graph to collect statistics. - const absl::flat_hash_map - representative_dataset_file_map_; - // Configures the calibration behavior. const tensorflow::quantization::CalibrationOptions calibration_options_; }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.cc new file mode 100644 index 00000000000000..4c8ab1df2ecc03 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace stablehlo::quantization { + +using ::tensorflow::quantization::RepresentativeDatasetFile; + +absl::StatusOr> +CreateRepresentativeDatasetFileMap(absl::Span + representative_dataset_configs) { + absl::flat_hash_map + repr_dataset_file_map{}; + + for (const RepresentativeDatasetConfig& dataset_config : + representative_dataset_configs) { + RepresentativeDatasetFile repr_dataset_file; + + repr_dataset_file.set_tfrecord_file_path(dataset_config.tf_record().path()); + // If the signature_key has not been explicitly specified, use the default + // value of "serving_default". + const std::string signature_key = dataset_config.has_signature_key() + ? 
dataset_config.signature_key() + : "serving_default"; + if (repr_dataset_file_map.contains(signature_key)) { + return absl::InvalidArgumentError( + absl::StrCat("RepresentativeDatasetConfig should not contain " + "duplicate signature key: ", + signature_key)); + } + repr_dataset_file_map[signature_key] = std::move(repr_dataset_file); + } + + return repr_dataset_file_map; +} + +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h new file mode 100644 index 00000000000000..33357630aa1098 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_ + +#include <string> + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace stablehlo::quantization { + +// Translates a set of `RepresentativeDatasetConfig` to signature key -> +// `RepresentativeDatasetFile` mapping. This is useful when using +// `RepresentativeDatasetConfig`s at places that accept the legacy +// `RepresentativeDatasetFile` mapping. +// Returns a non-OK status when there is a duplicate signature key among +// `representative_dataset_configs`. +absl::StatusOr<absl::flat_hash_map< +    std::string, tensorflow::quantization::RepresentativeDatasetFile>> +CreateRepresentativeDatasetFileMap(absl::Span<const RepresentativeDatasetConfig> + representative_dataset_configs); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc new file mode 100644 index 00000000000000..aaedfc72086f07 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h" + +#include +#include + +#include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tsl/platform/status_matchers.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::stablehlo::quantization::RepresentativeDatasetConfig; +using ::tensorflow::quantization::RepresentativeDatasetFile; +using ::testing::Contains; +using ::testing::HasSubstr; +using ::testing::Key; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +TEST(CreateRepresentativeDatasetFileMapTest, + ConfigWithoutExplicitSignatureKeyMappedToServingDefault) { + std::vector representative_dataset_configs; + + RepresentativeDatasetConfig config{}; + *(config.mutable_tf_record()->mutable_path()) = "test_path"; + representative_dataset_configs.push_back(config); + + const absl::StatusOr< + absl::flat_hash_map> + representative_dataset_file_map = + CreateRepresentativeDatasetFileMap(representative_dataset_configs); + + ASSERT_THAT(representative_dataset_file_map, IsOk()); + ASSERT_THAT(*representative_dataset_file_map, SizeIs(1)); + EXPECT_THAT(*representative_dataset_file_map, + Contains(Key("serving_default"))); + EXPECT_THAT(representative_dataset_file_map->at("serving_default") + 
.tfrecord_file_path(), + StrEq("test_path")); +} + +TEST(CreateRepresentativeDatasetFileMapTest, ConfigWithExplicitSignatureKey) { + std::vector representative_dataset_configs; + + RepresentativeDatasetConfig config{}; + config.set_signature_key("test_signature_key"); + *(config.mutable_tf_record()->mutable_path()) = "test_path"; + representative_dataset_configs.push_back(config); + + const absl::StatusOr< + absl::flat_hash_map> + representative_dataset_file_map = + CreateRepresentativeDatasetFileMap(representative_dataset_configs); + + ASSERT_THAT(representative_dataset_file_map, IsOk()); + ASSERT_THAT(*representative_dataset_file_map, SizeIs(1)); + EXPECT_THAT(*representative_dataset_file_map, + Contains(Key(StrEq("test_signature_key")))); + EXPECT_THAT(representative_dataset_file_map->at("test_signature_key") + .tfrecord_file_path(), + StrEq("test_path")); +} + +TEST(CreateRepresentativeDatasetFileMapTest, + ConfigWithDuplicateSignatureKeyReturnsInvalidArgumentError) { + std::vector representative_dataset_configs; + + RepresentativeDatasetConfig config_1{}; + config_1.set_signature_key("serving_default"); + *(config_1.mutable_tf_record()->mutable_path()) = "test_path_1"; + representative_dataset_configs.push_back(config_1); + + // Signature key is implicitly "serving_default". 
+ RepresentativeDatasetConfig config_2{}; + *(config_2.mutable_tf_record()->mutable_path()) = "test_path_2"; + representative_dataset_configs.push_back(config_2); + + const absl::StatusOr< + absl::flat_hash_map> + representative_dataset_file_map = + CreateRepresentativeDatasetFileMap(representative_dataset_configs); + + EXPECT_THAT(representative_dataset_file_map, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("duplicate signature key: serving_default"))); +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/third_party/xla/third_party/tsl/tsl/platform/jpeg.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc similarity index 52% rename from third_party/xla/third_party/tsl/tsl/platform/jpeg.h rename to tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index a7b640db03943f..679e1f8754be9b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/jpeg.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,18 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" -#ifndef TENSORFLOW_TSL_PLATFORM_JPEG_H_ -#define TENSORFLOW_TSL_PLATFORM_JPEG_H_ +namespace stablehlo::quantization { -#include -#include -#include -#include +QuantizationConfig PopulateDefaults( + const QuantizationConfig& user_provided_config) { + QuantizationConfig config = user_provided_config; -extern "C" { -#include "jerror.h" // from @libjpeg_turbo // IWYU pragma: export -#include "jpeglib.h" // from @libjpeg_turbo // IWYU pragma: export + PipelineConfig& pipeline_config = *config.mutable_pipeline_config(); + if (!pipeline_config.has_unpack_quantized_types()) { + pipeline_config.set_unpack_quantized_types(true); + } + + return config; } -#endif // TENSORFLOW_TSL_PLATFORM_JPEG_H_ +} // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h new file mode 100644 index 00000000000000..20b9efa4a60fa0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace stablehlo::quantization { + +// Returns a copy of `user_provided_config` with default values populated where +// the user did not explicitly specify. +QuantizationConfig PopulateDefaults( + const QuantizationConfig& user_provided_config); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 1b14bd56aaa44f..31c67f2d20c4ff 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -20,11 +20,45 @@ limitations under the License. 
#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::QuantizationSpecs; +using ::stablehlo::quantization::StaticRangePtqPreset; +using ::tensorflow::quantization::CalibrationOptions; + +void AddPreCalibrationPasses(OpPassManager& pm, + const CalibrationOptions& calibration_options, + const QuantizationSpecs& quantization_specs) { + pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); + pm.addNestedPass( + CreateInsertCustomAggregationOpsPass(calibration_options)); + pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); + // StableHLO Quantizer currently uses TF's calibration passes. Serialize + // the StableHLO module as tf.XlaCallModule to run calibration. 
+ AddCallModuleSerializationPasses(pm); +} + +void AddPostCalibrationPasses( + OpPassManager& pm, const PipelineConfig& pipeline_config, + const StaticRangePtqPreset& static_range_ptq_preset) { + QuantizeCompositeFunctionsPassOptions options; + options.enable_per_channel_quantized_weight_ = + static_range_ptq_preset.enable_per_channel_quantized_weight(); + pm.addNestedPass( + CreateConvertCustomAggregationOpToQuantStatsPass()); + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + if (pipeline_config.unpack_quantized_types()) { + AddStablehloQuantToIntPasses(pm); + } + AddCallModuleSerializationPasses(pm); +} + void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { pm.addPass(TF::CreateXlaCallModuleDeserializationPass()); pm.addPass(createRestoreFunctionNamePass()); @@ -60,6 +94,10 @@ void AddStablehloQuantToIntPasses(OpPassManager& pm) { // NOMUTANTS -- Add tests for individual passes with migration below. void AddCallModuleSerializationPasses(OpPassManager& pm) { AddShapeLegalizationPasses(pm); + // Add an inliner pass to inline quantized StableHLO functions (and others) so + // that StableHLO ops are properly grouped and converted into XlaCallModule + // ops by the ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass. + pm.addPass(createInlinerPass()); pm.addPass(createReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass()); // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass may create // duplicate constants. Add canonicalizer to deduplicate. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h index e5272732400365..5920619bd3fb8d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h @@ -16,9 +16,26 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace mlir::quant::stablehlo { +// Adds passes for static-range quantization pre-calibration. Inserts ops +// required to collect tensor statistics. +void AddPreCalibrationPasses( + OpPassManager& pm, + const ::tensorflow::quantization::CalibrationOptions& calibration_options, + const ::stablehlo::quantization::QuantizationSpecs& specs); + +// Adds passes for static-range quantization post-calibration. Utilizes tensor +// statistics collected from the calibration step and performs quantization. +void AddPostCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::StaticRangePtqPreset& + static_range_ptq_preset); + // Deserializes StableHLO functions serialized and embedded in XlaCallModuleOps. void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 7c52a1c7d6790b..6f5f10b48f41f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -17,20 +17,19 @@ limitations under the License. 
#include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/statusor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU: keep #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "tsl/platform/errors.h" namespace mlir::quant::stablehlo { +using ::stablehlo::quantization::PipelineConfig; using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::StaticRangePtqPreset; using ::tensorflow::quantization::RunPasses; PostCalibrationComponent::PostCalibrationComponent( @@ -39,20 +38,20 @@ PostCalibrationComponent::PostCalibrationComponent( absl::StatusOr PostCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { - TF_RETURN_IF_ERROR( - RunPasses(/*name=*/kName, - /*add_passes_func=*/[this](PassManager& pm) { AddPasses(pm); }, - *ctx_, module_op)); + TF_RETURN_IF_ERROR(RunPasses( + kName, /*add_passes_func=*/ + [&config, this](PassManager& pm) { + AddPostCalibrationPasses(pm, config.pipeline_config(), + config.static_range_ptq_preset()); + }, + *ctx_, module_op)); return module_op; } -void PostCalibrationComponent::AddPasses(OpPassManager& pm) const { - pm.addNestedPass( - CreateConvertCustomAggregationOpToQuantStatsPass()); - pm.addPass(createQuantizeCompositeFunctionsPass()); - pm.addPass(createOptimizeGraphPass()); - AddStablehloQuantToIntPasses(pm); - AddCallModuleSerializationPasses(pm); +void PostCalibrationComponent::AddPasses( + OpPassManager& pm, const 
StaticRangePtqPreset& static_range_ptq_preset, + const PipelineConfig& pipeline_config) const { + AddPostCalibrationPasses(pm, pipeline_config, static_range_ptq_preset); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h index 3ad71f05079162..3c218c9f857524 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -45,9 +45,11 @@ class PostCalibrationComponent : public Component { ModuleOp module_op, const ::stablehlo::quantization::QuantizationConfig& config) override; - // Adds MLIR passes to the pass manager. `Run` will essentially run these - // passes on the module op. - void AddPasses(OpPassManager& pm) const; + void AddPasses( + OpPassManager& pm, + const ::stablehlo::quantization::StaticRangePtqPreset& + static_range_ptq_preset, + const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; private: absl::Nonnull ctx_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index 2f9882417420dd..f54f947990866b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -19,14 +19,11 @@ limitations under the License. 
#include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/statusor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tsl/platform/errors.h" @@ -44,15 +41,9 @@ PreCalibrationComponent::PreCalibrationComponent( absl::StatusOr PreCalibrationComponent::Run( ModuleOp module_op, const QuantizationConfig& config) { TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kName, /*add_passes_func=*/ - [this](PassManager& pm) { - pm.addPass(createLiftQuantizableSpotsAsFunctionsPass()); - pm.addNestedPass( - CreateInsertCustomAggregationOpsPass(calibration_options_)); - pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass()); - // StableHLO Quantizer currently uses TF's calibration passes. Serialize - // the StableHLO module as tf.XlaCallModule to run calibration. 
- AddCallModuleSerializationPasses(pm); + kName, /*add_passes_func=*/ + [&config, this](PassManager& pm) { + AddPreCalibrationPasses(pm, calibration_options_, config.specs()); }, *ctx_, module_op)); return module_op; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc index 69f9679f90a357..37f83e57fb678d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc @@ -68,7 +68,7 @@ class PreCalibrationComponentTest : public QuantizationTestBase {}; TEST_F(PreCalibrationComponentTest, HasCustomAggregatorOpAndQuantizableFuncForSimpleDotGeneral) { - PreCalibrationComponent component(&ctx_, CalibrationOptions()); + PreCalibrationComponent component(ctx_.get(), CalibrationOptions()); OwningOpRef module_op = ParseModuleOpString(R"mlir( module attributes {} { func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> attributes {} { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc index dcde93f81d60e2..c756505f9c52ac 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc @@ -47,6 +47,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" @@ -77,7 +78,6 @@ using ::tensorflow::quantization::CalibrationOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PreprocessAndFreezeGraph; using ::tensorflow::quantization::PyFunctionLibrary; -using ::tensorflow::quantization::RepresentativeDatasetFile; using ::tensorflow::quantization::RunPasses; using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; @@ -105,7 +105,8 @@ absl::StatusOr> RunExportPasses( } if (absl::Status pass_run_status = RunPasses( - /*name=*/export_opts.debug_name, + /*name=*/ + export_opts.debug_name, /*add_passes_func=*/ [dup_constants = export_opts.duplicate_shape_determining_constants]( PassManager& pm) { AddExportPasses(pm, dup_constants); }, @@ -229,9 +230,6 @@ StaticRangePtqComponent::StaticRangePtqComponent( const absl::string_view src_saved_model_path, std::vector signature_keys, std::unordered_set tags, - absl::flat_hash_map - representative_dataset_file_map, absl::flat_hash_map signature_def_map, absl::flat_hash_map function_aliases) : ctx_(ctx) { @@ -244,7 +242,7 @@ StaticRangePtqComponent::StaticRangePtqComponent( ctx_, py_function_library, src_saved_model_path, std::move(function_aliases), std::move(tags), std::move(signature_def_map), std::move(signature_keys), - std::move(representative_dataset_file_map), calibration_options); + calibration_options); sub_components_[2] = std::make_unique(ctx_); } @@ -267,9 +265,7 @@ 
absl::Status QuantizeStaticRangePtq( const std::vector& signature_keys, const absl::flat_hash_map& signature_def_map, const absl::flat_hash_map& function_aliases, - const PyFunctionLibrary& py_function_library, - const absl::flat_hash_map& - representative_dataset_file_map) { + const PyFunctionLibrary& py_function_library) { std::unordered_set tags; tags.insert(quantization_config.tf_saved_model().tags().begin(), quantization_config.tf_saved_model().tags().end()); @@ -283,8 +279,7 @@ absl::Status QuantizeStaticRangePtq( StaticRangePtqComponent static_range_ptq_component( ctx.get(), &py_function_library, src_saved_model_path, signature_keys, - tags, representative_dataset_file_map, signature_def_map, - function_aliases); + tags, signature_def_map, function_aliases); TF_ASSIGN_OR_RETURN(module_op, static_range_ptq_component.Run( module_op, quantization_config)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h index 93697101e42178..4f2867b034cba3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -57,9 +57,6 @@ class StaticRangePtqComponent : public Component { absl::string_view src_saved_model_path, std::vector signature_keys, std::unordered_set tags, - absl::flat_hash_map - representative_dataset_file_map, absl::flat_hash_map signature_def_map, absl::flat_hash_map function_aliases); @@ -101,10 +98,7 @@ absl::Status QuantizeStaticRangePtq( const absl::flat_hash_map& signature_def_map, const absl::flat_hash_map& function_aliases, - const tensorflow::quantization::PyFunctionLibrary& py_function_library, - const absl::flat_hash_map< - std::string, tensorflow::quantization::RepresentativeDatasetFile>& - representative_dataset_file_map); + const tensorflow::quantization::PyFunctionLibrary& py_function_library); // 
LINT.ThenChange(../python/pywrap_quantization.cc:static_range_ptq) } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 2138a04d2bb389..b7d75897cddd07 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( @@ -17,6 +18,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", @@ -26,3 +28,20 @@ cc_library( "@stablehlo//:stablehlo_ops", ], ) + +tf_cc_test( + name = "stablehlo_op_quant_spec_test", + srcs = ["stablehlo_op_quant_spec_test.cc"], + deps = [ + ":stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index 9e832650c24a23..bbcff2dcdbe6d2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -59,8 +60,8 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { quant::GetUniformQuantizedTypeForBias}; } } - for (auto quantizable_operand : spec->coeff_op_quant_dim) { - spec->quantizable_operands.insert(quantizable_operand.first); + for (const auto [operand_idx, per_channel_dim] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(operand_idx); } } return spec; @@ -69,18 +70,17 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { auto scale_spec = std::make_unique(); if (llvm::isa(op)) { + mlir::stablehlo::ConcatenateOp, mlir::stablehlo::GatherOp, + mlir::stablehlo::PadOp, mlir::stablehlo::ReduceWindowOp, + mlir::stablehlo::ReshapeOp, mlir::stablehlo::SelectOp, + mlir::stablehlo::SliceOp, mlir::stablehlo::TransposeOp>(op)) { scale_spec->has_same_scale_requirement = true; } return scale_spec; } bool IsOpQuantizableStableHlo(Operation* op) { - if (mlir::isa(op)) { + if (isa(op)) { // Constant ops do not have QuantizableResult attribute but can be // quantized. 
return true; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc similarity index 56% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc index a469bd7c349c99..80a6d2fa451e6c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec_test.cc @@ -15,33 +15,28 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include #include #include "absl/strings/string_view.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/test_base.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" namespace mlir::quant::stablehlo { namespace { -using ::mlir::quant::QuantizationTestBase; +using ::testing::NotNull; class IsOpQuantizableStableHloTest : public QuantizationTestBase {}; // Quantizable ops: constants // Non-quantizable ops: normal StableHLO ops and terminators -constexpr absl::string_view module_constant_add = 
R"mlir( +constexpr absl::string_view kModuleConstantAdd = R"mlir( module { func.func @constant_add() -> (tensor<3x2xf32>) { %cst1 = stablehlo.constant dense<2.4> : tensor<3x2xf32> @@ -55,7 +50,7 @@ constexpr absl::string_view module_constant_add = R"mlir( // Quantizable ops: XlaCallModule op with "fully_quantizable" attribute and // same-scale StableHLO ops // Non-quantizable ops: quantize/dequantize ops -constexpr absl::string_view module_composite_same_scale = R"mlir( +constexpr absl::string_view kModuleCompositeSameScale = R"mlir( module { func.func @same_scale_after_composite() -> tensor<3x1xf32> { %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> @@ -70,7 +65,7 @@ constexpr absl::string_view module_composite_same_scale = R"mlir( )mlir"; // Non-quantizable ops: XlaCallModule op without "fully_quantizable" attribute -constexpr absl::string_view module_composite_no_attr = R"mlir( +constexpr absl::string_view kModuleCompositeNoAttr = R"mlir( module { func.func @composite_without_attr() -> tensor<1x3xf32> { %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @non_quantizable_composite, _original_entry_function = "non_quantizable_composite", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> @@ -80,97 +75,79 @@ constexpr absl::string_view module_composite_no_attr = R"mlir( )mlir"; TEST_F(IsOpQuantizableStableHloTest, ConstantOpQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - 
GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* constant_op = - FindOperationOfType(test_func); - bool is_constant_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(constant_op); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); - EXPECT_TRUE(is_constant_quantizable); + auto constant_op = + FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(constant_op)); } TEST_F(IsOpQuantizableStableHloTest, TerminatorOpNotQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* return_op = FindOperationOfType(test_func); - bool is_return_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(return_op); - - EXPECT_FALSE(is_return_quantizable); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); + + auto return_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(return_op)); } TEST_F(IsOpQuantizableStableHloTest, SameScaleOpQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* reshape_op = - FindOperationOfType(test_func); - bool is_reshape_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); - - EXPECT_TRUE(is_reshape_quantizable); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto reshape_op = FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(reshape_op)); } TEST_F(IsOpQuantizableStableHloTest, 
NonSameScaleOpNotQuantizable) { - OwningOpRef module_op_ref = - ParseModuleOpString(module_constant_add); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "constant_add"); - Operation* add_op = FindOperationOfType(test_func); - bool is_add_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(add_op); - - EXPECT_FALSE(is_add_quantizable); + OwningOpRef module_op_ref = ParseModuleOpString(kModuleConstantAdd); + auto test_func = module_op_ref->lookupSymbol("constant_add"); + ASSERT_THAT(test_func, NotNull()); + + auto add_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(add_op)); } TEST_F(IsOpQuantizableStableHloTest, ValidXlaCallModuleOpQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* xla_call_module_op = - FindOperationOfType(test_func); - bool is_xla_call_module_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); - - EXPECT_TRUE(is_xla_call_module_quantizable); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto xla_call_module_op = FindOperationOfType(test_func); + EXPECT_TRUE(IsOpQuantizableStableHlo(xla_call_module_op)); } TEST_F(IsOpQuantizableStableHloTest, InvalidXlaCallModuleOpNotQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_no_attr); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "composite_without_attr"); - Operation* xla_call_module_op = - FindOperationOfType(test_func); - bool is_xla_call_module_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); - - EXPECT_FALSE(is_xla_call_module_quantizable); + ParseModuleOpString(kModuleCompositeNoAttr); + auto test_func = + 
module_op_ref->lookupSymbol("composite_without_attr"); + ASSERT_THAT(test_func, NotNull()); + + auto xla_call_module_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(xla_call_module_op)); } TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOpNotQuantizable) { OwningOpRef module_op_ref = - ParseModuleOpString(module_composite_same_scale); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); - Operation* quantize_op = - FindOperationOfType(test_func); - Operation* dequantize_op = - FindOperationOfType(test_func); - bool is_quantize_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(quantize_op); - bool is_dequantize_quantizable = - mlir::quant::stablehlo::IsOpQuantizableStableHlo(dequantize_op); + ParseModuleOpString(kModuleCompositeSameScale); + auto test_func = + module_op_ref->lookupSymbol("same_scale_after_composite"); + ASSERT_THAT(test_func, NotNull()); + + auto quantize_op = FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(quantize_op)); - EXPECT_FALSE(is_quantize_quantizable); - EXPECT_FALSE(is_dequantize_quantizable); + auto dequantize_op = + FindOperationOfType(test_func); + EXPECT_FALSE(IsOpQuantizableStableHlo(dequantize_op)); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 633c193ab57c7c..f572a0795e3b77 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -89,7 +89,7 @@ UniformQuantizedPerAxisType GetPerChannelType(QuantType quant_type) { void GetQuantizationParams(OpBuilder &builder, Location loc, QuantType quant_type, Value &scales, Value &zero_points, bool output_zero_point_in_fp, - DenseIntElementsAttr &broadcast_dims) { 
+ DenseI64ArrayAttr &broadcast_dims) { // Get scales/zero points for per-tensor and per-axis quantization cases. if (auto *quant_per_tensor_type = std::get_if(&quant_type)) { @@ -140,8 +140,8 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getI32Type()), zero_points_vec)); } - broadcast_dims = DenseIntElementsAttr::get( - RankedTensorType::get({1}, builder.getI64Type()), + broadcast_dims = DenseI64ArrayAttr::get( + builder.getContext(), {static_cast(quant_per_channel_type.getQuantizedDimension())}); } } @@ -256,9 +256,8 @@ Value ApplyMergedScalesAndZps(OpBuilder &builder, Location loc, merged_scale_double.end()), merged_zp_float(merged_zp_double.begin(), merged_zp_double.end()); - auto broadcast_dims = DenseIntElementsAttr::get( - RankedTensorType::get({1}, builder.getI64Type()), - {quantized_dimension}); + auto broadcast_dims = + DenseI64ArrayAttr::get(builder.getContext(), {quantized_dimension}); Value merged_scale = builder.create( loc, DenseFPElementsAttr::get( RankedTensorType::get({channel_size}, builder.getF32Type()), @@ -367,7 +366,7 @@ class ConvertUniformQuantizeOp ConversionPatternRewriter &rewriter, QuantType quant_type) const { Value scales, zero_points; - DenseIntElementsAttr broadcast_dims; + DenseI64ArrayAttr broadcast_dims; GetQuantizationParams(rewriter, op->getLoc(), quant_type, scales, zero_points, /*output_zero_point_in_fp=*/true, broadcast_dims); @@ -425,7 +424,7 @@ class ConvertUniformDequantizeOp return failure(); } Value scales, zero_points; - DenseIntElementsAttr broadcast_dims; + DenseI64ArrayAttr broadcast_dims; GetQuantizationParams(rewriter, op->getLoc(), *quant_type, scales, zero_points, /*output_zero_point_in_fp=*/false, broadcast_dims); @@ -465,15 +464,41 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // We only handle cases where lhs, rhs and results all have quantized // element type. 
- if (failed(lhs_quant_type) || IsPerChannelType(*lhs_quant_type) || - failed(rhs_quant_type) || IsPerChannelType(*rhs_quant_type) || - failed(res_quant_type) || IsPerChannelType(*res_quant_type)) { + if (failed(lhs_quant_type) || failed(rhs_quant_type) || + failed(res_quant_type)) { op->emitError( - "AddOp requires the same quantized element type for all operands and " + "AddOp requires the quantized element type for all operands and " "results"); return failure(); } + if (IsPerChannelType(*lhs_quant_type) || + IsPerChannelType(*rhs_quant_type) || + IsPerChannelType(*res_quant_type)) { + // Handle Per-Channel Quantized Types. We only support lhs/rhs/result with + // exact same per-channel quantized types with I32 storage type. + if (!IsPerChannelType(*lhs_quant_type) || + !IsPerChannelType(*rhs_quant_type) || + !IsPerChannelType(*res_quant_type) || + GetPerChannelType(*lhs_quant_type) != + GetPerChannelType(*rhs_quant_type) || + GetPerChannelType(*lhs_quant_type) != + GetPerChannelType(*res_quant_type)) { + op->emitError( + "Per-channel quantized AddOp requires the same quantized element " + "type for all operands and results"); + return failure(); + } + if (!GetPerChannelType(*lhs_quant_type).getStorageType().isInteger(32)) { + // For server-side StableHLO Quantization, add is quantized only when + // fused with conv/dot ops, whose output must be i32. + op->emitError("Per-channel quantized AddOp requires i32 storage type"); + return failure(); + } + return matchAndRewritePerChannel(op, adaptor, rewriter, + GetPerChannelType(*lhs_quant_type)); + } + // TODO: b/260280919 - Consider avoiding conversion to int32. 
auto res_int32_tensor_type = op.getResult().getType().clone(rewriter.getI32Type()); @@ -536,6 +561,33 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { return success(); } + + LogicalResult matchAndRewritePerChannel( + mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + UniformQuantizedPerAxisType quant_type) const { + // We assume lhs/rhs/result have the same quantized type with i32 storage. + Value add_result = rewriter.create( + op->getLoc(), adaptor.getLhs(), adaptor.getRhs()); + // Add zp contribution if it is non-zero for any channel. + if (llvm::any_of(quant_type.getZeroPoints(), + [](int64_t zp) { return zp != 0; })) { + SmallVector zps_vec(quant_type.getZeroPoints().begin(), + quant_type.getZeroPoints().end()); + Value zps = rewriter.create( + op->getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(zps_vec.size())}, + rewriter.getI32Type()), + zps_vec)); + add_result = rewriter.create( + op->getLoc(), add_result, zps, + rewriter.getDenseI64ArrayAttr( + {static_cast(quant_type.getQuantizedDimension())})); + } + rewriter.replaceOp(op, add_result); + return success(); + } }; // This is a convenient struct for holding dimension numbers for dot-like ops diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc index a24db896d79f1d..a6312e067af50f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc @@ -596,7 +596,7 @@ class ConvertUniformQuantizedAddOp // rhs (bias) is always 1D that broadcasts to the last dim of lhs. 
auto broadcast_dims = - mhlo::GetI64ElementsAttr({lhs_type.getRank() - 1}, &rewriter); + rewriter.getDenseI64ArrayAttr({lhs_type.getRank() - 1}); auto rhs_type = GetUniformQuantizedType( op, op.getRhs().getType(), op.getRhsScales(), op.getRhsZeroPoints(), @@ -651,8 +651,7 @@ class ConvertUniformQuantizedClipByValueOp if (quantization_axis >= 0) { broadcast_dims_values.push_back(quantization_axis); } - auto broadcast_dims = - mhlo::GetI64ElementsAttr(broadcast_dims_values, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr(broadcast_dims_values); auto min_max_type = GetUniformQuantizedType( op, op.getMin().getType(), op.getScales(), op.getZeroPoints(), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc index 4c20b6bebdcdad..93946fdc320a97 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" #include "xla/client/client_library.h" #include "xla/shape.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" @@ -45,7 +45,7 @@ class LegalizeTFQuantTest : public Test { tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; mlir_to_hlo_args.mlir_module = mlir_module_string; tensorflow::se::Platform* platform = - tensorflow::se::MultiPlatformManager::PlatformWithName("Host").value(); + tensorflow::se::PlatformManager::PlatformWithName("Host").value(); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); tensorflow::tpu::TPUCompileMetadataProto metadata_proto; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc index 3e5b7e1f8d5ace..17300611e356ce 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc @@ -47,7 +47,29 @@ class BFloat16TypeConverter : public TypeConverter { } }; -// An Op is illegal iff it is non-UQ op and it contains qint types. +// This helper function makes legality check easier. 
Both convert ops in the +// patterns below are considered legal: +// - BitcastConvertOp(i32 -> f32) + ConvertOp(f32 -> bf16) +// - ConvertOp(bf16 -> f32) -> BitcastConvertOp(f32 -> i32) +template +bool IsConvertOpLegal(ConvertOp convert_op, BFloat16TypeConverter &converter) { + if (!converter.isLegal(convert_op.getOperand().getType())) { + auto other_convert_op = dyn_cast_or_null( + convert_op.getOperand().getDefiningOp()); + return other_convert_op && + converter.isLegal(other_convert_op.getOperand().getType()); + } else if (!converter.isLegal(convert_op.getResult().getType())) { + if (!convert_op.getResult().hasOneUse()) { + return false; + } + auto other_convert_op = dyn_cast_or_null( + *convert_op.getResult().getUsers().begin()); + return other_convert_op && + converter.isLegal(other_convert_op.getResult().getType()); + } + return true; +} + class BFloat16TypeConversionTarget : public ConversionTarget { public: explicit BFloat16TypeConversionTarget(MLIRContext &ctx, @@ -58,6 +80,15 @@ class BFloat16TypeConversionTarget : public ConversionTarget { // types do not contain. 
if (auto func = dyn_cast(op)) { if (!converter_.isSignatureLegal(func.getFunctionType())) return false; + } else if (auto bitcast_convert_op = + dyn_cast(op)) { + return IsConvertOpLegal(bitcast_convert_op, + converter_); + } else if (auto convert_op = dyn_cast(op)) { + return IsConvertOpLegal(convert_op, + converter_); } return converter_.isLegal(op); }); @@ -69,7 +100,7 @@ class BFloat16TypeConversionTarget : public ConversionTarget { class BFloat16TypePattern : public ConversionPattern { public: - BFloat16TypePattern(MLIRContext *ctx, TypeConverter &converter) + BFloat16TypePattern(TypeConverter &converter, MLIRContext *ctx) : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite( @@ -78,6 +109,10 @@ class BFloat16TypePattern : public ConversionPattern { if (getTypeConverter()->isLegal(op)) { return failure(); } + if (isa(op)) { + // Skip BitcastConvertOp, which is handled by the other pattern. + return failure(); + } // Update the results. SmallVector new_results; @@ -118,6 +153,42 @@ class BFloat16TypePattern : public ConversionPattern { return success(); } }; + +class BitcastConvertOpPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::stablehlo::BitcastConvertOp op, + mlir::stablehlo::BitcastConvertOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + bool is_input_legal = + getTypeConverter()->isLegal(op.getOperand().getType()); + bool is_output_legal = + getTypeConverter()->isLegal(op.getResult().getType()); + if (is_input_legal && is_output_legal) { + return failure(); + } else if (is_input_legal) { + // output is f32, we bitcast_convert to f32 and then convert to bf16. 
+ Value output = rewriter.create( + op->getLoc(), op.getResult().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getResult().getType()), + output); + } else if (is_output_legal) { + // input is f32, we convert from bf16 and then bitcast_convert. + Value output = rewriter.create( + op->getLoc(), op.getOperand().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), output); + } else { + // Both input/output are f32. Convert to no-op. + rewriter.replaceOp(op, adaptor.getOperand()); + } + return success(); + } +}; } // namespace #define GEN_PASS_DEF_CONVERTFUNCTOBFLOAT16PASS @@ -140,7 +211,8 @@ void ConvertFuncToBfloat16Pass::runOnOperation() { RewritePatternSet patterns(context); BFloat16TypeConverter converter; - patterns.add(context, converter); + patterns.add(converter, + context); populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); BFloat16TypeConversionTarget target(*context, converter); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc index a5f35c9a64b5bd..5f69cd7ee3e082 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep @@ -43,22 +45,26 @@ limitations under the License. 
namespace mlir::quant::stablehlo { -FailureOr ConvertSerializedStableHloModuleToBfloat16( - MLIRContext* context, StringRef serialized_stablehlo_module) { +absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + StringRef serialized_stablehlo_module) { // StableHLO module is empty often because the XlaCallModuleOp is already // deserialized, e.g. after invoking XlaCallModuleDeserializationPass. We // don't handle this situation. - if (serialized_stablehlo_module.empty()) return failure(); + if (serialized_stablehlo_module.empty()) { + return absl::InvalidArgumentError("StableHLO module is empty."); + } + MLIRContext context; OwningOpRef stablehlo_module_op = mlir::stablehlo::deserializePortableArtifact(serialized_stablehlo_module, - context); + &context); // Convert the StableHLO module to bfloat16. - PassManager pm(context); + PassManager pm(&context); pm.addNestedPass(createConvertFuncToBfloat16Pass()); if (failed(pm.run(stablehlo_module_op.get()))) { - return failure(); + return absl::InternalError( + "Failed to convert StableHLO module to bfloat16."); } std::string bytecode; @@ -66,7 +72,7 @@ FailureOr ConvertSerializedStableHloModuleToBfloat16( if (failed(mlir::stablehlo::serializePortableArtifact( stablehlo_module_op.get(), mlir::stablehlo::getCurrentVersion(), os))) { - return failure(); + return absl::InternalError("Failed to serialize StableHLO module."); } return bytecode; } @@ -95,9 +101,11 @@ void ConvertXlaCallModuleOpToBfloat16Pass::runOnOperation() { auto result = func_op->walk([&](TF::XlaCallModuleOp op) { // Converts the serialized StableHLO module to bfloat16. 
- auto result = ConvertSerializedStableHloModuleToBfloat16( - &getContext(), op.getModuleAttr()); - if (failed(result)) { + auto result = + ConvertSerializedStableHloModuleToBfloat16(op.getModuleAttr()); + if (!result.ok()) { + llvm::errs() << "Failed to convert StableHLO module to bfloat16: " + << result.status().message(); return WalkResult::interrupt(); } op.setModuleAttr(StringAttr::get(&getContext(), *result)); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 6f13634b317aa4..dbe88208ae7b03 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -31,6 +35,11 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/regexp.h" // IWYU pragma: keep + +#define DEBUG_TYPE "lift_quantizable_spots_as_functions" namespace mlir::quant::stablehlo { @@ -39,13 +48,16 @@ namespace mlir::quant::stablehlo { namespace { +using ::stablehlo::quantization::FunctionNameMatcherSpec; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizationSpec; +using ::stablehlo::quantization::QuantizationSpecs; + // TODO - b/303543789: Move the helper functions below to a separate util. // Fetches the default or null attribute, used for pattern matching. Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { - if (!attr) { - return builder.getStringAttr(kNullAttributeValue); - } - return attr; + if (attr) return attr; + return builder.getStringAttr(kNullAttributeValue); } // Checks whether the value of a constant equals the given float, regardless @@ -62,6 +74,12 @@ bool FloatValueEquals(const Attribute& attr, const double value) { }); } +// Lifts quantizable units as separate functions, thereby identifying the +// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how +// quantizable units are lifted. +// +// FileCheck test cases using various `QuantizationSpecs` can be seen at +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. 
class LiftQuantizableSpotsAsFunctionsPass : public impl::LiftQuantizableSpotsAsFunctionsPassBase< LiftQuantizableSpotsAsFunctionsPass> { @@ -69,10 +87,19 @@ class LiftQuantizableSpotsAsFunctionsPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( LiftQuantizableSpotsAsFunctionsPass) - explicit LiftQuantizableSpotsAsFunctionsPass() = default; + LiftQuantizableSpotsAsFunctionsPass() = default; + + // Constructor with explicit user-provided `QuantizationSpecs`. + explicit LiftQuantizableSpotsAsFunctionsPass( + QuantizationSpecs quantization_specs) + : quantization_specs_(std::move(quantization_specs)) {} private: void runOnOperation() override; + + // No explicit quantization spec is specified by default. Implicitly this + // means that all quantizable units will be identified and lifted. + QuantizationSpecs quantization_specs_{}; }; namespace simple_patterns { @@ -83,6 +110,91 @@ namespace fusion_patterns { #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc" } +// Returns a `func::FuncOp` in `module_op` (not nested) whose name matches +// `name`. Returns null if no such a function exists. +// TODO: b/307620778 - Factor out "FindMainFuncOp" functionality. +func::FuncOp FindFuncOp(ModuleOp module_op, const StringRef name) { + auto func_ops = module_op.getOps(); + auto func_itr = llvm::find_if(func_ops, [name](func::FuncOp func_op) { + return func_op.getName() == name; + }); + + if (func_itr == func_ops.end()) return {}; + return *func_itr; +} + +// Quantizable Unit matcher that uses lifted function's name for matching. +class FunctionNameMatcher { + public: + explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec) + : match_regex_(GetMatchRegex(spec)) {} + + // Returns `true` when matched with the entry function of + // `xla_call_module_op`. 
+ bool Match(TF::XlaCallModuleOp xla_call_module_op) const { + if (match_regex_ == nullptr) return false; + + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + + return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT + } + + private: + // Returns an owned `RE2` object that corresponds to the `spec`. Returns + // `nullptr` if the `spec` is invalid. + // NOLINTNEXTLINE - RE2 included via TSL regexp.h + std::unique_ptr GetMatchRegex(const FunctionNameMatcherSpec& spec) { + const std::string& regex = spec.regex(); + if (regex.empty()) return nullptr; + + return std::make_unique(regex); // NOLINT + } + + // Regex object used for matching against a lifted function's name. + std::unique_ptr match_regex_; // NOLINT +}; + +// Applies quantization spec to all matched lifted functions. At this point only +// denylisting (`NoQuantization`) will be applied if specs is nonempty. +// TODO: b/307620778 - Support more advanced selective quantization methods. +LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, + ModuleOp module_op) { + func::FuncOp main_func = FindFuncOp(module_op, "main"); + if (!main_func) return failure(); + + const Method& quantization_method = spec.method(); + if (!quantization_method.has_no_quantization()) { + module_op->emitError() << "Unsupported quantization method: " + << quantization_method.DebugString() << "\n"; + return failure(); + } + + const FunctionNameMatcher matcher(spec.matcher().function_name()); + for (auto xla_call_module_op : main_func.getOps()) { + if (!matcher.Match(xla_call_module_op)) continue; + + // Disable quantization when matched. + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + func::FuncOp lifted_func = FindFuncOp(module_op, lifted_func_name); + + // Remove relevant attributes that enable quantization. 
This essentially + // disables quantization for the matched `xla_call_module_op`. + xla_call_module_op->removeAttr("_original_entry_function"); + xla_call_module_op->removeAttr("_tfl_quant_trait"); + lifted_func->removeAttr("tf_quant.composite_function"); + + LLVM_DEBUG(llvm::dbgs() << "Disabled quantization for quantizable unit: " + << lifted_func_name << "\n"); + } + return success(); +} + void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { MLIRContext* ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -101,8 +213,26 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { // Remove all attr_map attributes. module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + + // Perform selective quantization. Iterates over the quantization specs and + // applies quantization methods to each matched lifted function. + for (const QuantizationSpec& spec : quantization_specs_.specs()) { + if (failed(ApplyQuantizationSpec(spec, module_op))) { + signalPassFailure(); + return; + } + } } } // namespace +// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined +// `QuantizationSpecs`. 
+std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const QuantizationSpecs& quantization_specs) { + return std::make_unique( + quantization_specs); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index 6e134a60e3816e..a0bc228397465b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -27,12 +27,25 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" // Pattern rules for lifting ops with bias as functions //===----------------------------------------------------------------------===// +def LiftDotGeneralWithBiasSameShape : Pat< + (StableHLO_AddOp:$res + (StableHLO_DotGeneralOp + $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $bias), + (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_fn"> + (ArgumentList $lhs, $rhs, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; + def LiftConvWithBias : Pat< (StableHLO_AddOp:$res (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), @@ -52,7 +65,7 @@ def LiftDotGeneralWithBias : Pat< (StableHLO_AddOp:$res (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, 
$precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), @@ -216,13 +229,29 @@ def LiftDotGeneralWithRelu6 : Pat< // Pattern rules for lifting ops with bias and activation as functions //===----------------------------------------------------------------------===// +def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< + (StableHLO_MaxOp:$res + (StableHLO_AddOp + (StableHLO_DotGeneralOp + $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $bias), + (StableHLO_ConstantOp $cst)), + (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu_fn"> + (ArgumentList $lhs, $rhs, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + [(IsNotInLiftedFunc $res), + (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; + def LiftConvWithBiasAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_and_relu_fn"> (ArgumentList $lhs, $rhs, $bias), @@ -245,7 +274,7 @@ def LiftDotGeneralWithBiasAndRelu : Pat< (StableHLO_AddOp (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu_fn"> (ArgumentList $lhs, $rhs, $bias), @@ -303,6 +332,21 @@ def LiftDotGeneralWithBiasAndReluDynamic : Pat< [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], 
(addBenefit 15)>; +def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< + (StableHLO_ClampOp:$res + (StableHLO_ConstantOp $cst_0), + (StableHLO_AddOp + (StableHLO_DotGeneralOp + $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $bias), + (StableHLO_ConstantOp $cst_1)), + (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu6_fn"> + (ArgumentList $lhs, $rhs, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; def LiftConvWithBiasAndRelu6 : Pat< (StableHLO_ClampOp:$res @@ -311,7 +355,7 @@ def LiftConvWithBiasAndRelu6 : Pat< (StableHLO_ConvolutionOp $lhs, $rhs, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $dimension_numbers, $feature_group_count, $batch_group_count, $precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_conv_with_bias_and_relu6_fn"> (ArgumentList $lhs, $rhs, $bias), @@ -334,7 +378,7 @@ def LiftDotGeneralWithBiasAndRelu6 : Pat< (StableHLO_AddOp (StableHLO_DotGeneralOp $lhs, $rhs, $dot_dimension_numbers, $precision_config), - $bias), + (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu6_fn"> (ArgumentList $lhs, $rhs, $bias), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h index d89b044b189ec0..5bb3e58be01d58 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -19,13 +19,13 @@ limitations under the License. 
#include #include +#include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" namespace mlir::quant::stablehlo { @@ -33,22 +33,22 @@ namespace mlir::quant::stablehlo { // Creates a `QuantizePass` that quantizes ops according to surrounding qcast / // dcast ops. std::unique_ptr> CreateQuantizePass( - const quant::QuantizationSpecs& quantization_specs); + const quant::QuantizationSpecs& quantization_specs, + bool enable_per_channel_quantized_weight = true); // Creates a pass that quantizes weight component of StableHLO graph. std::unique_ptr> CreateQuantizeWeightPass( const ::stablehlo::quantization::QuantizationComponentSpec& quantization_component_spec = {}); -// Creates an instance of the StableHLO dialect PrepareQuantize pass without any -// arguments. Preset method of SRQ is set to the quantization option by default. -std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization = false, int bit_width = 8); - // Converts a serialized StableHLO module to bfloat16 and output serialized // module. 
-FailureOr ConvertSerializedStableHloModuleToBfloat16( - MLIRContext* context, StringRef serialized_stablehlo_module); +absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + StringRef serialized_stablehlo_module); + +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs); // Adds generated pass default constructors or options definitions. #define GEN_PASS_DECL diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 3e1c431b6e6fd5..dcebf70eb5c1e9 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -21,25 +21,6 @@ def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> let constructor = "mlir::quant::stablehlo::CreateQuantizeWeightPass()"; } -def PrepareQuantizePass : Pass<"stablehlo-prepare-quantize", "mlir::func::FuncOp"> { - let summary = "Prepare StableHLO dialect for static range quantization."; - let options = [ - Option<"enable_per_channel_quantization_", - "enable-per-channel-quantization", - "bool", /*default=*/"true", - "Whether enable per-channel quantized weights.">, - Option<"bit_width_", "bit-width", "int", /*default=*/"8", - "Bitwidth of quantized integer"> - ]; - let constructor = "mlir::quant::stablehlo::CreatePrepareQuantizePass()"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", - "mlir::quantfork::QuantizationForkDialect", - "mlir::arith::ArithDialect", - ]; -} - def UnfuseMhloBatchNormPass : Pass<"stablehlo-unfuse-mhlo-batch-norm", "mlir::func::FuncOp"> { let summary = "Unfuses batch normalization into arithmetic ops."; } @@ -53,6 +34,7 @@ def LiftQuantizableSpotsAsFunctionsPass : Pass<"stablehlo-lift-quantizable-spots that disperse values. 
(ex: convolution, dot_general) }]; let dependentDialects = [ + "mlir::func::FuncDialect", "mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", ]; @@ -67,40 +49,62 @@ def ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : Pass<"stablehlo- }]; } -def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { - let summary = "Applies static-range quantization on ops."; +def RestoreFunctionNamePass : Pass<"stablehlo-restore-function-name", "ModuleOp"> { + let summary = "Restores function name from XlaCallModule op."; +} + +def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "ModuleOp"> { + let summary = "Quantize composite functions with QDQ input / outputs."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"mlir_dump_file_name_", "mlir-dump-file-name", + "std::optional", /*default=*/"std::nullopt", + "MLIR dump file name."> + ]; let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::stablehlo::StablehloDialect", "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", + "TF::TensorFlowDialect", ]; } -def RestoreFunctionNamePass : Pass<"stablehlo-restore-function-name", "ModuleOp"> { - let summary = "Restores function name from XlaCallModule op."; +def PrepareQuantizePass : Pass<"stablehlo-prepare-quantize", "mlir::func::FuncOp"> { + let summary = "Prepare StableHLO dialect for static range quantization by converting quantfork.stats into quantfork.qcast and dcast ops."; + let options = [ + Option<"enable_per_channel_quantized_weight_", + "enable-per-channel-quantized-weight", + "bool", /*default=*/"true", + "Whether to enable per-channel quantized weights.">, + Option<"bit_width_", "bit-width", "int", /*default=*/"8", + "Bitwidth of quantized integer"> + ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + 
"mlir::quant::QuantizationDialect", + "mlir::quantfork::QuantizationForkDialect", + "mlir::arith::ArithDialect", + ]; } -def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { - let summary = "Apply clean-up after quantization."; +def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { + let summary = "Applies static-range quantization on ops by converting quantfork.qcast, quantfork.dcast, and float op into uniform quantized ops ."; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", + "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", ]; } -def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-functions", "ModuleOp"> { - let summary = "Quantize composite functions with QDQ input / outputs."; - let options = [ - Option<"mlir_dump_file_name_", "mlir-dump-file-name", - "std::optional", /*default=*/"std::nullopt", - "MLIR dump file name."> - ]; +def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> { + let summary = "Apply clean-up after quantization."; let dependentDialects = [ - "mlir::arith::ArithDialect", "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", "mlir::quantfork::QuantizationForkDialect", - "TF::TensorFlowDialect", ]; } @@ -124,11 +128,6 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu ]; } -def PopulateShapePass : Pass<"populate-shape", "ModuleOp"> { - let summary = "Populate output shape with known information for CustomAggregatorOp and XlaCallModuleOp."; - let dependentDialects = ["TF::TensorFlowDialect"]; -} - def OptimizeGraphPass : Pass<"optimize-graph", "ModuleOp"> { let summary = "Optimize the sub-optimal patterns after quantization."; let dependentDialects = ["mlir::stablehlo::StablehloDialect",]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc deleted file mode 
100644 index 0d4f0594f5c7d8..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/populate_shape.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "llvm/Support/Casting.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Support/TypeID.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/core/ir/types/dialect.h" - -namespace mlir::quant::stablehlo { - -#define GEN_PASS_DEF_POPULATESHAPEPASS -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" - 
-namespace { - -class PopulateShapeForCustomAggregatorOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF::CustomAggregatorOp op, TF::CustomAggregatorOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto input_shape_type = op.getInput().getType().dyn_cast(); - auto output_shape_type = op.getOutput().getType(); - - if (!input_shape_type.isa()) { - input_shape_type = adaptor.getInput().getType(); - } - - if (input_shape_type.isa() && - !output_shape_type.isa() && - TF::HasCompatibleElementTypes(input_shape_type, output_shape_type)) { - auto new_op = rewriter.create( - op->getLoc(), /*output=*/input_shape_type, - /*args=*/adaptor.getInput(), - /*Id=*/op.getId()); - new_op->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, new_op); - return success(); - } - return failure(); - } -}; - -class PopulateShapeForXlaCallModuleOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF::XlaCallModuleOp op, TF::XlaCallModuleOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->getNumResults() != 1) { - op->emitError("XlaCallModuleOp doesn't have 1 output."); - return failure(); - } - // Assume XlaCallModuleOp only has 1 output. 
- auto output_shape_type = op->getResultTypes()[0]; - if (!output_shape_type.isa()) { - auto output_shape_attr = op.getSout()[0].dyn_cast(); - if (!output_shape_attr.hasRank()) { - return failure(); - } - auto new_output_shape_type = tensorflow::GetTypeFromTFTensorShape( - output_shape_attr.getShape(), - getElementTypeOrSelf(op.getResultTypes()[0])); - auto new_op = rewriter.create( - op->getLoc(), /*output=*/new_output_shape_type, - /*args=*/adaptor.getOperands(), - /*version=*/op.getVersionAttr(), - /*module=*/op.getModuleAttr(), - /*Sout=*/op.getSoutAttr()); - new_op->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, new_op); - return success(); - } - return failure(); - } -}; - -class PopulateShapePass - : public impl::PopulateShapePassBase { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PopulateShapePass) - - explicit PopulateShapePass() = default; - - private: - void runOnOperation() override; -}; - -void PopulateShapePass::runOnOperation() { - Operation *op = getOperation(); - MLIRContext *context = op->getContext(); - RewritePatternSet patterns(context); - ConversionTarget target(*context); - target.addDynamicallyLegalOp([](Operation *op) { - auto custom_aggregator_op = llvm::dyn_cast(op); - return custom_aggregator_op.getInput().getType().isa() && - custom_aggregator_op.getOutput().getType().isa(); - }); - target.addDynamicallyLegalOp([](Operation *op) { - if (op->getNumResults() != 1) return true; - return op->getResultTypes()[0].isa(); - }); - - patterns - .add( - context); - - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - return signalPassFailure(); - } -} -} // namespace - -} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 1291b0f7aa83eb..688e21b7d898dc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -40,11 +41,11 @@ namespace mlir { namespace quant { namespace stablehlo { -namespace { - #define GEN_PASS_DEF_PREPAREQUANTIZEPASS #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" +namespace { + // Applies prepare quantization on the model in TF dialect. This pass runs // before the quantization pass and propagate the quantization parameters // across ops. This step is necessary for post-training quantization and also @@ -53,12 +54,14 @@ namespace { class PrepareQuantizePass : public impl::PrepareQuantizePassBase { public: - PrepareQuantizePass() = default; - PrepareQuantizePass(const PrepareQuantizePass&) = default; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass) + + using impl::PrepareQuantizePassBase< + PrepareQuantizePass>::PrepareQuantizePassBase; - explicit PrepareQuantizePass(bool enable_per_channel_quantization, + explicit PrepareQuantizePass(bool enable_per_channel_quantized_weight, int bit_width) { - enable_per_channel_quantization_ = enable_per_channel_quantization; + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; bit_width_ = bit_width; } @@ -162,9 +165,11 @@ void PrepareQuantizePass::runOnOperation() { // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). 
ApplyQuantizationParamsPropagation( - func, /*is_signed=*/true, bit_width_, !enable_per_channel_quantization_, - GetStableHloOpQuantSpec, GetStableHloQuantScaleSpec, - /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false); + func, /*is_signed=*/true, bit_width_, + !enable_per_channel_quantized_weight_, GetStableHloOpQuantSpec, + GetStableHloQuantScaleSpec, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); // Restore constants as stablehlo::ConstantOp. RewritePatternSet patterns_2(ctx); @@ -180,9 +185,9 @@ void PrepareQuantizePass::runOnOperation() { // Creates an instance of the TensorFlow dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( - bool enable_per_channel_quantization, int bit_width) { - return std::make_unique(enable_per_channel_quantization, - bit_width); + bool enable_per_channel_quantized_weight, int bit_width) { + return std::make_unique( + enable_per_channel_quantized_weight, bit_width); } } // namespace stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 91129477a726af..76430bec75e4ce 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "llvm/ADT/STLExtras.h" @@ -48,8 +49,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -59,12 +60,15 @@ namespace mlir::quant::stablehlo { namespace { +using ::mlir::quant::FindUserOfType; using ::mlir::quant::TryCast; using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; using ::mlir::stablehlo::ConcatenateOp; using ::mlir::stablehlo::ConvolutionOp; using ::mlir::stablehlo::DotGeneralOp; using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::GatherOp; using ::mlir::stablehlo::GetDimensionSizeOp; using ::mlir::stablehlo::ReshapeOp; using ::mlir::stablehlo::UniformQuantizeOp; @@ -79,14 +83,17 @@ bool IsQuantizedTensorType(const Type type) { type.cast().getElementType().isa(); } -// Returns dynamically broadcasted user op of an input op. Returns null if -// the op is not dynamically broadcasted or not the intended type. -// Dynamic shapes usually has the following pattern. In the example below, -// the input operand would be stablehlo.convolution op, and return value would -// be stablehlo.add op. +// Returns broadcasted user op of an input op. Returns null if +// the op is not broadcasted or not the intended type. +// Supports both static broadcast and dynamic broadcast. // Note that the patterns below differ from lifted patterns as // ShapeLegalizeToHloPass is ran prior to running this pass. 
// +// Dynamically broadcasted bias due to unknown input batch size +// usually has the following pattern. In the example below, +// the input operand would be stablehlo.convolution op, and return value would +// be stablehlo.add op. +// // ``` // %0 = stablehlo.constant dense<3> // %1 = stablehlo.constant dense<4> @@ -100,42 +107,44 @@ bool IsQuantizedTensorType(const Type type) { // %6 = stablehlo.concatenate %5, %0, %1, %2, dim = 0 : // (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) // -> tensor<4xi32> -// %7 = stablehlo.dynamic_broadcast_in_dims %arg2, %6 +// %7 = stablehlo.dynamic_broadcast_in_dim %arg2, %6 // %8 = stablehlo.add %3, %7 // ``` +// +// Statically broadcasted bias will be broadcasted to match the accumulation. +// ``` +// %3 = stablehlo.convolution(%%arg0, %%arg1) : +// (tensor, tensor<2x3x3x2xf32>) -> tensor +// %4 = stablehlo.broadcast_in_dim %arg2, %3 +// %5 = stablehlo.add %3, %4 +// ``` template -Operation* GetDynamicallyBroadcastedUserOp(Operation* op) { - FailureOr get_dimension_size_op = - TryCast(op->getNextNode(), - /*name=*/"get_dimension_size_op"); - if (failed(get_dimension_size_op)) { - return nullptr; - } - auto reshape_op = TryCast((*get_dimension_size_op)->getNextNode(), - /*name=*/"reshape_op"); - if (failed(reshape_op)) { - return nullptr; - } - auto concatenate_op = TryCast((*reshape_op)->getNextNode(), - /*name=*/"concatenate_op"); - if (failed(concatenate_op)) { - return nullptr; +Operation* GetBroadcastedUserOp(Operation* op) { + // Broadcast bias for known input shape. + auto broadcast_in_dim_op = FindUserOfType(op); + if (broadcast_in_dim_op != nullptr) { + auto target_op = FindUserOfType(broadcast_in_dim_op); + if (target_op != nullptr) return target_op; } + // Broadcast bias for unknown input shape. 
+ auto get_dimension_size_op = FindUserOfType(op); + if (get_dimension_size_op == nullptr) return nullptr; + + auto reshape_op = FindUserOfType(get_dimension_size_op); + if (reshape_op == nullptr) return nullptr; + + auto concatenate_op = FindUserOfType(reshape_op); + if (concatenate_op == nullptr) return nullptr; + auto dynamic_broadcast_in_dim_op = - TryCast((*concatenate_op)->getNextNode(), - /*name=*/"dynamic_broadcast_in_dim_op"); - if (failed(dynamic_broadcast_in_dim_op)) { - return nullptr; - } - auto target_op = TryCast((*dynamic_broadcast_in_dim_op)->getNextNode(), - /*name=*/"target_op"); - if (failed(target_op)) { - return nullptr; - } - return *target_op; + FindUserOfType(concatenate_op); + if (dynamic_broadcast_in_dim_op == nullptr) return nullptr; + + auto target_op = FindUserOfType(dynamic_broadcast_in_dim_op); + return target_op; } -// Checks if all inputs and outputs are quantized. +// Checks if one of the inputs and outputs are quantized. bool HasQuantizedOperandOrOutput(Operation* call_op) { SmallVector arg_types; for (const Value arg : call_op->getOperands()) { @@ -147,8 +156,8 @@ bool HasQuantizedOperandOrOutput(Operation* call_op) { output_types.push_back(output.getType()); } - return absl::c_all_of(arg_types, IsQuantizedTensorType) && - absl::c_all_of(output_types, IsQuantizedTensorType); + return absl::c_any_of(arg_types, IsQuantizedTensorType) && + absl::c_any_of(output_types, IsQuantizedTensorType); } // Gets the corresponding quantized function name from the given function name. @@ -161,7 +170,7 @@ std::string GetQuantizedFunctionName(const StringRef func_name) { // Returns true if `xla_call_module_op` is quantized. To be considered // quantized, it should meet three conditions: -// 1. At least one of the inputs or outputs should be a uniform quantized type. +// 1. At least one of the inputs and outputs should be a uniform quantized type. // 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. // 3. 
It should also have the `kEntryFuncAttrName` attribute, which points to // the function that `xla_call_module_op` represents. @@ -211,6 +220,9 @@ void SetQuantizedFunctionType(PatternRewriter& rewriter, } // Creates a UniformQuantize op and sets it as return op. +// The requantize scale and zero point should be determined from the +// entry_func_op's output, containing information on layerStats of the +// entire function. void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, func::FuncOp entry_func_op, const Type func_result_type) { @@ -226,35 +238,37 @@ template // and sets the quantized bias as the return op. void CreateAndReturnQuantizedBiasPattern( Operation* op, PatternRewriter& rewriter, func::FuncOp entry_func_op, - const Type func_result_type, const Type gemm_style_quantized_element_type, - GemmStyleOp gemm_style_op, double result_scale) { + const Type func_result_type, const Type accumulation_quantized_element_type, + GemmStyleOp gemm_style_op) { Value bias_op = op->getOperand(1); Value add_op_result = op->getResult(0); - // For bias add with dynamic shape, quantize the broadcasted bias. - if (auto dynamic_bcast_op = - cast_or_null(bias_op.getDefiningOp())) { - const UniformQuantizedType dynamic_bcast_quantized_element_type = - CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), - *rewriter.getContext(), result_scale, - /*zero_point=*/0); - - Value dynamic_bcast_op_result = dynamic_bcast_op->getResult(0); - auto dynamic_bcast_op_result_type = - dynamic_bcast_op_result.getType().cast(); - const ArrayRef dynamic_bcast_shape = - dynamic_bcast_op_result_type.getShape(); - - const TensorType new_dynamic_bcast_op_result_type = - dynamic_bcast_op_result_type.cloneWith( - dynamic_bcast_shape, gemm_style_quantized_element_type); - dynamic_bcast_op_result.setType(new_dynamic_bcast_op_result_type); + + // Broadcast bias value if unmatched with output shape. 
+ auto bcast_op = TryCast(bias_op.getDefiningOp(), + /*name=*/"broadcast_in_dim_op"); + + if (failed(bcast_op)) { + bcast_op = TryCast( + bias_op.getDefiningOp(), + /*name=*/"dynamic_broadcast_in_dim_op"); } + // Update the bias type for both static and dynamic broadcasts. + if (succeeded(bcast_op)) { + Value bcast_op_result = (*bcast_op)->getResult(0); + auto bcast_op_result_type = + bcast_op_result.getType().cast(); + const ArrayRef bcast_shape = bcast_op_result_type.getShape(); + const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith( + bcast_shape, accumulation_quantized_element_type); + bcast_op_result.setType(new_bcast_op_result_type); + } + const auto add_op_result_type = add_op_result.getType().cast(); const ArrayRef add_op_shape = add_op_result_type.getShape(); // For quantized bias add case, lhs, rhs, and result have the same types. const TensorType new_add_op_result_type = add_op_result_type.cloneWith( - add_op_shape, gemm_style_quantized_element_type); + add_op_shape, accumulation_quantized_element_type); add_op_result.setType(new_add_op_result_type); AddOp bias_add_op = @@ -287,12 +301,14 @@ template LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { auto op_iterator_range = entry_func_op.getOps(); if (op_iterator_range.empty()) { - LLVM_DEBUG(llvm::dbgs() << "Function does not have GemmStyle op.\n"); + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << GemmStyleOp::getOperationName() << " op.\n"); return failure(); } if (!isa( (*op_iterator_range.begin()).getResult().getType())) { - LLVM_DEBUG(llvm::dbgs() << "GemmStyle op must have ranked tensor type.\n"); + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op must have ranked tensor type.\n"); return failure(); } @@ -300,8 +316,8 @@ LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { entry_func_op.getBody().getArguments(); // Function must have input, filter, and optionally bias. 
if (operands.size() != 2 && operands.size() != 3) { - LLVM_DEBUG(llvm::dbgs() - << "GemmStyle op function should have 2 or 3 operands.\n"); + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op function should have 2 or 3 operands.\n"); return failure(); } return success(); @@ -309,68 +325,144 @@ LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { // Gemm Style Op: glossary/gemm. template -void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { - // Update the output type of the gemm_style op. - GemmStyleOp gemm_style_op = *entry_func_op.getOps().begin(); +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, + bool enable_per_channel_quantized_weight) { + const GemmStyleOp gemm_style_op = + *entry_func_op.getOps().begin(); const Type input_type = entry_func_op.getArgumentTypes()[0]; const Type filter_type = entry_func_op.getArgumentTypes()[1]; const Type func_result_type = entry_func_op.getResultTypes()[0]; - const double input_scale = - getElementTypeOrSelf(input_type).cast().getScale(); - const double filter_scale = - getElementTypeOrSelf(filter_type).cast().getScale(); - const double result_scale = input_scale * filter_scale; - - // Define the intermediate output type, which is an i32 quantized type. - // This is intermediate because the final output type of the entry_func_op - // should be an i8 quantized type. 
- const UniformQuantizedType gemm_style_quantized_element_type = - CreateI32F32UniformQuantizedType(gemm_style_op->getLoc(), - *rewriter.getContext(), result_scale, - /*zero_point=*/0); - Value gemm_style_op_result = gemm_style_op->getResult(0); - auto gemm_style_op_result_type = + const auto gemm_style_op_result_type = gemm_style_op_result.getType().cast(); const ArrayRef gemm_style_shape = gemm_style_op_result_type.getShape(); - const TensorType new_gemm_style_op_result_type = - gemm_style_op_result_type.cloneWith(gemm_style_shape, - gemm_style_quantized_element_type); + Type accumulation_quantized_element_type; + TensorType new_gemm_style_op_result_type; + + const double input_scale = + getElementTypeOrSelf(input_type).cast().getScale(); + + if (enable_per_channel_quantized_weight) { + ArrayRef filter_scales = getElementTypeOrSelf(filter_type) + .cast() + .getScales(); + std::vector result_scales; + result_scales.reserve(filter_scales.size()); + + for (double filter_scale : filter_scales) { + result_scales.push_back(input_scale * filter_scale); + } + + const ArrayRef zero_points = + getElementTypeOrSelf(filter_type) + .cast() + .getZeroPoints(); + + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + accumulation_quantized_element_type = + CreateI32F32UniformQuantizedPerAxisType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scales, + zero_points, /*quantization_dimension=*/3); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } else { + const double filter_scale = getElementTypeOrSelf(filter_type) + .cast() + .getScale(); + double result_scale = input_scale * filter_scale; + + accumulation_quantized_element_type = CreateI32F32UniformQuantizedType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } + 
gemm_style_op_result.setType(new_gemm_style_op_result_type); rewriter.setInsertionPointAfter(gemm_style_op); - Operation* next_op = gemm_style_op->getNextNode(); + Operation* next_op = FindUserOfType<>(gemm_style_op); + // If activation exists, omit clipping op. + // Since out_scale and out_zp are computed based on clipped range, + // explicit activation clipping op is not required. if (isa(next_op) && gemm_style_op->hasOneUse()) { // bias fusion CreateAndReturnQuantizedBiasPattern( next_op, rewriter, entry_func_op, func_result_type, - gemm_style_quantized_element_type, gemm_style_op, result_scale); + accumulation_quantized_element_type, gemm_style_op); } else if (auto add_op = cast_or_null( - GetDynamicallyBroadcastedUserOp(gemm_style_op))) { - // dynamic bias fusion + GetBroadcastedUserOp(gemm_style_op))) { + // broadcasted bias fusion rewriter.setInsertionPointAfter(add_op); CreateAndReturnQuantizedBiasPattern( add_op, rewriter, entry_func_op, func_result_type, - gemm_style_quantized_element_type, gemm_style_op, result_scale); + accumulation_quantized_element_type, gemm_style_op); } else { // Non fusible op - // If an op is used multiple times and is not a dynamic shape case, do not - // apply quantization of fused patterns to prevent removal of dependee ops. + // If an op is used multiple times and is not a broadcasted shape case, + // do not apply quantization of fused patterns to prevent removal of + // dependee ops. CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, func_result_type); } } +template +// Match for tensor manipulation op. 
+LogicalResult MatchSingularOp(func::FuncOp entry_func_op) { + auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << SingularOp::getOperationName() << " op.\n"); + return failure(); + } + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << SingularOp::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + return success(); +} + +template +void RewriteSingularOp(func::FuncOp entry_func_op, PatternRewriter& rewriter) { + SingularOp singular_op = *entry_func_op.getOps().begin(); + + const Type operand_type = entry_func_op.getArgumentTypes()[0]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + // Get the quantized tensor manipulation op's output type and update. + Value singular_op_result = singular_op.getResult(); + auto singular_op_result_type = + singular_op_result.getType().cast(); + const ArrayRef singular_op_shape = + singular_op_result_type.getShape(); + const TensorType new_singular_op_result_type = + singular_op_result_type.cloneWith( + singular_op_shape, + getElementTypeOrSelf(operand_type).cast()); + singular_op_result.setType(new_singular_op_result_type); + + // Create requantization op and return. + rewriter.setInsertionPointAfter(singular_op); + CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, + func_result_type); +} + // Quantizes the entry function's body containing a `DotGeneralOp`. 
class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeDotGeneralOpPattern() = default; + explicit QuantizeDotGeneralOpPattern( + bool enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -378,14 +470,19 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); + RewriteGemmStyleOp( + entry_func_op, rewriter, + /*enable_per_channel_quantized_weight=*/false); } }; // Quantizes the entry function's body containing a `ConvolutionOp`. class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeConvolutionOpPattern() = default; + explicit QuantizeConvolutionOpPattern( + bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -393,8 +490,32 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, PatternRewriter& rewriter) const override { - RewriteGemmStyleOp(entry_func_op, rewriter); + RewriteGemmStyleOp(entry_func_op, rewriter, + enable_per_channel_quantized_weight_); } + + private: + bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body containing a `GatherOp`. 
+class QuantizeGatherOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeGatherOpPattern(bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op) const override { + return MatchSingularOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, + PatternRewriter& rewriter) const override { + RewriteSingularOp(entry_func_op, rewriter); + } + + private: + bool enable_per_channel_quantized_weight_; }; // Converts `entry_func_op` to be quantized according to the respective @@ -453,8 +574,11 @@ template >> class XlaCallModuleOpToCallOp : public OpRewritePattern { public: - explicit XlaCallModuleOpToCallOp(MLIRContext& ctx) - : OpRewritePattern(&ctx) {} + explicit XlaCallModuleOpToCallOp(MLIRContext& ctx, + bool enable_per_channel_quantized_weight) + : OpRewritePattern(&ctx), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} LogicalResult match(TF::XlaCallModuleOp op) const override { ModuleOp module_op = op->getParentOfType(); @@ -468,15 +592,19 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { op->emitError("Failed to find a valid entry function."); return failure(); } - return FuncBodyRewritePatternT().match(entry_func_op); + return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op); } void rewrite(TF::XlaCallModuleOp xla_call_module_op, PatternRewriter& rewriter) const override { ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT()); + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_)); } + + private: + bool enable_per_channel_quantized_weight_; }; // Quantizes op with regions such as stablehlo.reduce_window op. 
@@ -670,7 +798,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) { if (type.getElementType().isa()) { return false; } - if (type.getElementType().isa()) { + if (type.getElementType() + .isa()) { has_quantized_types = true; } } @@ -680,7 +809,8 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) { if (type.getElementType().isa()) { return false; } - if (type.getElementType().isa()) { + if (type.getElementType() + .isa()) { has_quantized_types = true; } } @@ -749,9 +879,17 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { // TODO: b/307620428 - Increase fused op coverage for static range quantization. void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns) { - patterns.add, - XlaCallModuleOpToCallOp>(ctx); + RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight) { + patterns.add>( + ctx, enable_per_channel_quantized_weight); + // By default, we set `enable_per_channel_quantized_weight` to true for + // passes to ensure per-channel quantization for all supported ops. + // For ops that do not yet support per-channel quantization, explicitly + // mark as false like below. We will soon add support for per-channel + // quantization of the following ops. + patterns.add>( + ctx, /*enable_per_channel_quantized_weight=*/false); } void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, @@ -759,4 +897,10 @@ void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, patterns.add(ctx); } +void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + // TODO: b/307620772 - Per-channel quantization for gather. 
+ patterns.add>( + ctx, /*enable_per_channel_quantized_weight=*/false); +} } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 91170115ce2baa..9922e5bd69eb49 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -124,11 +124,20 @@ class StableHloQuantizationPattern : public RewritePattern { // Const-> QuantizeOp pattern will be handled separately. return failure(); } - if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + if (Operation* quantizing_op = quantize_operand.getDefiningOp(); + quantizing_op != nullptr) { quantizing_ops.push_back(quantizing_op); + } else { + // When `QuantizeOpT`'s operand does not have a defining op, it means it + // is a `BlockArgument`. The pattern does not match if there is no op to + // quantize. + return failure(); } } + // Safeguard check to ensure that there is at least one quantizable op. + if (quantizing_ops.empty()) return failure(); + absl::flat_hash_set ops_blocklist = quant_params_.quant_spec.ops_blocklist; absl::flat_hash_set nodes_blocklist = @@ -276,15 +285,19 @@ class StableHloQuantizationPattern : public RewritePattern { }; // Gemm Style Op: glossary/gemm. -// Populates conversion patterns to unfuse batch normalization operations. void PopulateFusedGemmStylePatterns(MLIRContext& ctx, - RewritePatternSet& patterns); + RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); // Populates pattern for quantization of ops with regions such as // stablehlo.reduce_window op. void PopulateQuantizeOpWithRegionPattern(MLIRContext& ctx, RewritePatternSet& patterns); +// Populates conversion patterns for unary data movement ops. 
+void PopulateQuantizeSingularOpPatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + } // namespace mlir::quant::stablehlo #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 8d321d9269345c..fd5898d686be96 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -44,7 +44,7 @@ namespace mlir::quant::stablehlo { namespace { // Base struct for quantization. -template +template struct StableHloQuantizationBase : public StableHloQuantizationPattern { explicit QuantizePass() = default; - explicit QuantizePass(const QuantizationSpecs& quant_specs) - : quant_specs_(quant_specs) {} + explicit QuantizePass(const QuantizationSpecs& quant_specs, + bool enable_per_channel_quantized_weight) + : quant_specs_(quant_specs), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} - QuantizePass(const QuantizePass& other) : quant_specs_(other.quant_specs_) {} + QuantizePass(const QuantizePass& other) + : quant_specs_(other.quant_specs_), + enable_per_channel_quantized_weight_( + other.enable_per_channel_quantized_weight_) {} private: void runOnOperation() override; QuantizationSpecs quant_specs_; + bool enable_per_channel_quantized_weight_; }; void QuantizePass::runOnOperation() { @@ -131,14 +138,14 @@ void QuantizePass::runOnOperation() { patterns.add( &ctx, quant_params); PopulateQuantizeOpWithRegionPattern(ctx, patterns); - PopulateFusedGemmStylePatterns(ctx, patterns); + PopulateFusedGemmStylePatterns(ctx, patterns, + enable_per_channel_quantized_weight_); + PopulateQuantizeSingularOpPatterns(ctx, patterns); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // 
causing this to result in a convergence failure. Consider this as a // best-effort. - // TODO: b/305469508 - Make QuantizationPattern converge if there are no - // patterns that are rewritable. module_op.emitWarning("Failed to converge pattern at QuantizePass."); } } @@ -146,8 +153,10 @@ void QuantizePass::runOnOperation() { } // namespace std::unique_ptr> CreateQuantizePass( - const QuantizationSpecs& quantization_specs) { - return std::make_unique(quantization_specs); + const QuantizationSpecs& quantization_specs, + bool enable_per_channel_quantized_weight) { + return std::make_unique(quantization_specs, + enable_per_channel_quantized_weight); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 026c0742615128..0d491a8cb66404 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -25,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep @@ -50,6 +53,11 @@ class QuantizeCompositeFunctionsPass using impl::QuantizeCompositeFunctionsPassBase< QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; + explicit QuantizeCompositeFunctionsPass( + bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + private: void runOnOperation() override; }; @@ -65,11 +73,16 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { // (XlaCallModuleOps) with quantized input and output types, which are not // allowed in the TF dialect. pm.enableVerifier(false); - - pm.addNestedPass(CreatePrepareQuantizePass()); + PrepareQuantizePassOptions options; + options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + // Change this to user-given bit width once we have custom configuration. + options.bit_width_ = 8; + pm.addNestedPass(createPrepareQuantizePass(options)); // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. 
- pm.addPass(CreateQuantizePass(quant_specs)); + pm.addPass( + CreateQuantizePass(quant_specs, enable_per_channel_quantized_weight_)); pm.addNestedPass(createPostQuantizePass()); ModuleOp module_op = getOperation(); @@ -79,7 +92,13 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { signalPassFailure(); } } - } // namespace +// Creates an instance of the TensorFlow dialect QuantizeCompositeFunctionsPass. +std::unique_ptr> CreateQuantizeCompositeFunctionsPass( + bool enable_per_channel_quantized_weight) { + return std::make_unique( + enable_per_channel_quantized_weight); +} + } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index aa43b64b97b10b..6a152843e3278a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -45,7 +46,6 @@ namespace mlir::quant::stablehlo { namespace { -constexpr StringRef kQuantizeTargetOpAttr = "tf_quant.composite_function"; constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; @@ -73,19 +73,6 @@ class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass void runOnOperation() override; }; -// Finds the main function from module_op. Returns nullptr if not found. -// The model's signature keys will contain "@serving_default" as default TF -// Model signature, or "@main" if it is in being exported from MLIR module to -// GraphDef. -func::FuncOp GetMainFunc(ModuleOp module_op) { - for (auto func_op : module_op.getOps()) { - if (func_op.getSymName().equals("main") || - func_op.getSymName().equals("serving_default")) - return func_op; - } - return nullptr; -} - // Creates a unique stablehlo function name based on op order. 
std::string CreateStablehloFunctionName(const int id) { return Twine("_stablehlo_main_").concat(std::to_string(id)).str(); @@ -447,7 +434,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: runOnOperation() { ModuleOp module_op = getOperation(); - func::FuncOp main_func = GetMainFunc(module_op); + func::FuncOp main_func = FindMainFuncOp(module_op); if (!main_func) return; DuplicateSmallConstantOps(module_op, main_func); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td index b7e2abd31d03ba..38d60e94f97e9a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td @@ -35,6 +35,11 @@ def TestPostCalibrationComponentPass : Pass<"stablehlo-test-post-calibration-com let description = [{ Runs the post-calibration passes for post-training quantization. }]; + let options = [ + Option<"unpack_quantized_types_", "unpack-quantized-types", "bool", + /*default=*/"true", "Unpacks ops with uniform quantized types into " + "operations without uniform quantized types (mostly i8 or i32)."> + ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", "mlir::func::FuncDialect", "mlir::mhlo::MhloDialect", @@ -43,3 +48,30 @@ def TestPostCalibrationComponentPass : Pass<"stablehlo-test-post-calibration-com "mlir::quantfork::QuantizationForkDialect", ]; } + +def TestTFToStablehloPass : Pass<"stablehlo-test-tf-to-stablehlo", "mlir::ModuleOp"> { + let summary = "Test-only pass to test TFToStablehloPasses."; + let description = [{ + Runs the TFToStablehloPasses. 
+ }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", + "mlir::chlo::ChloDialect", "mlir::quant::QuantizationDialect", + "mlir::mhlo::MhloDialect", "mlir::shape::ShapeDialect", + "mlir::sparse_tensor::SparseTensorDialect", "mlir::vhlo::VhloDialect", + ]; +} + +def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : + Pass<"stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs", "mlir::ModuleOp"> { + let summary = "Test-only pass for testing the LiftQuantizableSpotsAsFunctionsPass with a predefined QuantizationSpecs."; + let description = [{ + This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but + has predefined `QuantizationSpecs` to make FileCheck testing easier. + }]; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::stablehlo::StablehloDialect", + "TF::TensorFlowDialect", + ]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc new file mode 100644 index 00000000000000..e8cb185cb7b55d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo::testing { + +// NOLINTNEXTLINE - Automatically generated. +#define GEN_PASS_DEF_TESTLIFTQUANTIZABLESPOTSASFUNCTIONSWITHQUANTIZATIONSPECSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + +namespace { + +using ::stablehlo::quantization::QuantizationSpecs; +using ::tsl::protobuf::TextFormat; +// NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS. +using ::tsl::protobuf::io::ArrayInputStream; + +// Configure `QuantizationSpecs` to disable quantization for all dot_general +// quantizable units. 
+constexpr absl::string_view kSpecsDisableAllDotGeneralByFuncName = + R"pb(specs + [ { + matcher { function_name { regex: "composite_dot_general_.*" } } + method { no_quantization {} } + }])pb"; + +class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass + : public impl:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass> { + public: + using impl::TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass>:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass) + + private: + void runOnOperation() override; +}; + +// Parses a text proto into a `QuantizationSpecs` proto. Returns +// `InvalidArgumentError` if `text_proto` is invalid. +absl::StatusOr ParseQuantizationSpecsTextProto( + const absl::string_view text_proto) { + QuantizationSpecs quantization_specs; + TextFormat::Parser parser; + ArrayInputStream input_stream(text_proto.data(), text_proto.size()); + if (parser.Parse(&input_stream, &quantization_specs)) { + return quantization_specs; + } + return absl::InvalidArgumentError("Could not parse text proto."); +} + +void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass:: + runOnOperation() { + PassManager pass_manager{&getContext()}; + + const absl::StatusOr quantization_specs = + ParseQuantizationSpecsTextProto(kSpecsDisableAllDotGeneralByFuncName); + if (!quantization_specs.ok()) { + signalPassFailure(); + return; + } + + pass_manager.addPass( + CreateLiftQuantizableSpotsAsFunctionsPass(*quantization_specs)); + + if (failed(pass_manager.run(getOperation()))) { + signalPassFailure(); + return; + } +} + +} // namespace +} // namespace mlir::quant::stablehlo::testing diff --git 
a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc index f53d420fe6fc3a..88fa9e59b4977d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc @@ -25,6 +25,7 @@ limitations under the License. #include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" // IWYU pragma: keep @@ -38,10 +39,16 @@ namespace mlir::quant::stablehlo::testing { namespace { +using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::StaticRangePtqPreset; + class TestPostCalibrationComponentPass : public impl::TestPostCalibrationComponentPassBase< TestPostCalibrationComponentPass> { public: + using impl::TestPostCalibrationComponentPassBase< + TestPostCalibrationComponentPass>::TestPostCalibrationComponentPassBase; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPostCalibrationComponentPass) private: @@ -54,8 +61,12 @@ void TestPostCalibrationComponentPass::runOnOperation() { OpPassManager pm(ModuleOp::getOperationName()); + StaticRangePtqPreset static_range_ptq_preset; + PipelineConfig pipeline_config; + pipeline_config.set_unpack_quantized_types(unpack_quantized_types_); + PostCalibrationComponent component(&ctx); - component.AddPasses(pm); + 
component.AddPasses(pm, static_range_ptq_preset, pipeline_config); // Adds a XlaCallModuleOp deserialization pass for easier testing by // inspecting the contents of serialized StableHLO function. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc new file mode 100644 index 00000000000000..3af53a213b0064 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc @@ -0,0 +1,69 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo::testing { + +#define GEN_PASS_DEF_TESTTFTOSTABLEHLOPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + +namespace { + +using ::tensorflow::quantization::AddTFToStablehloPasses; +using ::tensorflow::quantization::RunPassesOnModuleOp; + +class TestTFToStablehloPass + : public impl::TestTFToStablehloPassBase { 
+ public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTFToStablehloPass) + + private: + void runOnOperation() override; +}; + +void TestTFToStablehloPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = &getContext(); + mlir::PassManager pm(ctx); + + AddTFToStablehloPasses(pm); + if (!RunPassesOnModuleOp( + /*mlir_dump_file_name=*/"test_tf_to_stablehlo_pass", pm, module_op) + .ok()) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 3493fdf3cbde92..a91ceec6e151f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -32,7 +32,6 @@ pytype_strict_library( deps = [ ":pywrap_quantization", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", "//tensorflow/compiler/mlir/quantization/tensorflow/python:save_model", @@ -57,6 +56,7 @@ pytype_strict_library( "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:save", "//tensorflow/python/types:core", "//third_party/py/numpy", @@ -67,16 +67,25 @@ pytype_strict_library( tf_py_strict_test( name = "quantize_model_test", srcs = ["integration_test/quantize_model_test.py"], + shard_count = 50, # Parallelize the test to avoid timeouts. 
deps = [ ":quantization", ":quantize_model_test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", + "//tensorflow/python/module", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:tag_constants", + "//tensorflow/python/types:core", "@absl_py//absl/testing:parameterized", ], ) @@ -86,6 +95,7 @@ tf_python_pybind_extension( srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config_impl", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters", "@pybind11", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 59d728c23813f8..f359981eaf89d7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== import itertools -from typing import Optional, Sequence +from typing import Mapping, Optional, Sequence from absl.testing import parameterized import numpy as np @@ -22,11 +22,19 @@ from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.module import module +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import save from tensorflow.python.saved_model import tag_constants +from tensorflow.python.types import core def parameter_combinations(test_parameters): @@ -46,24 +54,33 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): @parameterized.parameters( parameter_combinations([{ - 'activation_fn': [None], - 'has_bias': [True, False], - 'dim_sizes': [ + 'bias_fn': ( + None, + nn_ops.bias_add, + ), + 'activation_fn': ( + None, + nn_ops.relu, + nn_ops.relu6, + ), + 'dim_sizes': ( # tf.MatMul cases. ([None, 1024], [1024, 3]), # dynamic batch dim. ([1, 1024], [1024, 3]), # tf.BatchMatMul cases. 
([10, 1, 1024], [10, 1024, 3]), ([2, 3, 1, 1024], [2, 3, 1024, 3]), - ], + ), + 'rng_seed': (1230, 1231, 1232, 1233), }]) ) @test_util.run_in_graph_and_eager_modes def test_matmul_ptq_model( self, + bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], - has_bias: bool, dim_sizes: Sequence[int], + rng_seed: int, ): lhs_dim_size, rhs_dim_size = dim_sizes input_shape = (*lhs_dim_size,) @@ -73,11 +90,11 @@ def test_matmul_ptq_model( input_shape, filter_shape, self._input_saved_model_path, - has_bias, + bias_fn, activation_fn, ) - rng = np.random.default_rng(seed=1235) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -128,7 +145,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: @parameterized.parameters( parameter_combinations([{ - 'same_scale_op': [ + 'same_scale_op': ( 'concatenate', 'gather', 'max_pool', @@ -137,13 +154,15 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'select', 'slice', 'transpose', - ], + ), + 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes def test_matmul_and_same_scale_ptq_model( self, same_scale_op: str, + rng_seed: int, ): input_shape = (2, 3, 1, 1024) filter_shape = (2, 3, 1024, 3) @@ -156,7 +175,7 @@ def test_matmul_and_same_scale_ptq_model( same_scale_op, ) - rng = np.random.default_rng(seed=1235) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -205,42 +224,56 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. 
self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) - @parameterized.named_parameters( - { - 'testcase_name': 'none', - 'activation_fn': None, - 'has_bias': False, - 'has_batch_norm': False, - 'input_shape_dynamic': False, - 'enable_per_channel_quantization': False, - }, + @parameterized.parameters( + parameter_combinations([{ + 'bias_fn': ( + None, + nn_ops.bias_add, + ), + 'activation_fn': ( + None, + nn_ops.relu, + nn_ops.relu6, + ), + 'has_batch_norm': (False,), + 'input_shape_dynamic': ( + False, + True, + ), + 'enable_per_channel_quantized_weight': ( + False, + True, + ), + 'rng_seed': (10, 11, 12, 13), + }]) ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model( self, + bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], - has_bias: bool, has_batch_norm: bool, input_shape_dynamic: bool, - enable_per_channel_quantization: bool, + enable_per_channel_quantized_weight: bool, + rng_seed: int, dilations: Sequence[int] = None, ): - input_shape = (None, None, None, 3) if input_shape_dynamic else (1, 3, 4, 3) + input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) filter_shape = (2, 3, 3, 2) strides = (1, 1, 1, 1) model = self._create_conv2d_model( input_shape, filter_shape, self._input_saved_model_path, - has_bias, - has_batch_norm, + bias_fn, activation_fn, + has_batch_norm, strides, dilations, ) # Generate model input data. 
- rng = np.random.default_rng(seed=1224) + rng = np.random.default_rng(rng_seed) static_input_shape = [dim if dim is not None else 2 for dim in input_shape] input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( @@ -268,7 +301,8 @@ def data_gen() -> repr_dataset.RepresentativeDataset: qc.RepresentativeDatasetConfig( tf_record=qc.TfRecordFile(path=dataset_path) ) - ] + ], + enable_per_channel_quantized_weight=enable_per_channel_quantized_weight, ), tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), ) @@ -290,10 +324,19 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # values are arbitrary. self.assertAllClose(new_outputs, expected_outputs, rtol=0.02, atol=0.05) - @parameterized.parameters(('abc,cde->abde',), ('abc,dce->abde',)) + @parameterized.parameters( + parameter_combinations([{ + 'equation': ( + 'abc,cde->abde', + 'abc,dce->abde', + ), + 'rng_seed': (82, 82732, 4444, 14), + }]) + ) def test_einsum_ptq_model( self, equation: str, + rng_seed: int, ): _, y_shape, bias_shape, x_signature, y_signature = ( self._prepare_sample_einsum_datashapes(equation, use_bias=True) @@ -309,7 +352,7 @@ def test_einsum_ptq_model( ) # Generate model input data. 
- rng = np.random.default_rng(seed=1231) + rng = np.random.default_rng(rng_seed) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') ) @@ -373,6 +416,239 @@ def test_when_preset_not_srq_raises_error(self): config, ) + @test_util.run_in_graph_and_eager_modes + def test_ptq_denylist_basic(self): + """Tests that the op is not quantized when no quantization is enabled.""" + input_shape = (1, 2) + model = self._create_matmul_model( + input_shape, + weight_shape=(2, 3), + saved_model_path=self._input_saved_model_path, + ) + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda: rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn()} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + # Disable quantization for the quantizable unit (lifted function) whose + # function name starts with "composite_dot_general". 
+ specs=qc.QuantizationSpecs( + specs=[ + qc.QuantizationSpec( + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='composite_dot_general.*' + ) + ), + method=qc.Method(no_quantization={}), + ) + ] + ), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn()) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Indirectly tests that the model is not quantized by asserting that there + # are negligible numeric difference. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.000001) + + @test_util.run_in_graph_and_eager_modes + def test_ptq_selective_denylist(self): + """Tests that the op is not quantized when no quantization is enabled.""" + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda shape: rng.uniform( + low=-1.0, high=1.0, size=shape + ).astype(np.float32) + + class TwoMatmulModel(module.Module): + """A model with two matmul ops.""" + + @def_function.function + def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs a matrix multiplication. + + Args: + input_tensor: Input tensor to matmul with the filter. 
+ + Returns: + A 'output' -> output tensor mapping + """ + out = math_ops.matmul(input_tensor, random_tensor_gen_fn((2, 3))) + out = math_ops.matmul(out, random_tensor_gen_fn((3, 4))) + return {'output': out} + + model = TwoMatmulModel() + input_shape = (1, 2) + + save.save( + model, + self._input_saved_model_path, + signatures=model.matmul.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn(input_shape)} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ), + ], + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + # Disable quantization for the quantizable unit (lifted function) whose + # function name matches "composite_dot_general_fn_1". + # "composite_dot_general_fn_2" will be quantized. 
+ specs=qc.QuantizationSpecs( + specs=[ + qc.QuantizationSpec( + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='composite_dot_general_fn_1' + ) + ), + method=qc.Method(no_quantization={}), + ) + ] + ), + ) + + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn(input_shape)) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Indirectly tests that the model is only partially quantized. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.011) + + @test_util.run_in_graph_and_eager_modes + def test_ptq_quantization_method_not_applied_when_matcher_mismatch(self): + """Tests that quantization method is not applied to unmatched units.""" + input_shape = (1, 2) + model = self._create_matmul_model( + input_shape, + weight_shape=(2, 3), + saved_model_path=self._input_saved_model_path, + ) + + rng = np.random.default_rng(1230) + random_tensor_gen_fn = lambda: rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(50): + yield {'input_tensor': random_tensor_gen_fn()} + + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = qc.QuantizationConfig( + static_range_ptq_preset=qc.StaticRangePtqPreset( + representative_datasets=[ + qc.RepresentativeDatasetConfig( + tf_record=qc.TfRecordFile(path=dataset_path) + ) + ] + ), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + specs=qc.QuantizationSpecs( + specs=[ + 
qc.QuantizationSpec( + # Provide a regex that wouldn't match any quantizable units. + matcher=qc.MatcherSpec( + function_name=qc.FunctionNameMatcherSpec( + regex='.*invalid_function_name.*' + ), + ), + method=qc.Method(no_quantization={}), + ), + ], + ), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + input_data = ops.convert_to_tensor(random_tensor_gen_fn()) + expected_outputs = model.matmul(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + # Tests that the quantized graph outputs similar values. They also shouldn't + # be exactly the same. Indirectly proves that the `FunctionNameMatcherSpec` + # with regex '.*invalid_function_name.*' did not match the quantizable unit. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.04) + self.assertNotAllClose(new_outputs, expected_outputs, rtol=0.00001) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py index a33ccb1b824e3f..9a5b9225c78a82 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test_base.py @@ -13,7 +13,8 @@ # limitations under the License. 
# ============================================================================== """Base test class for quantize_model Tests.""" -from typing import Mapping, Sequence, Optional, Tuple, List + +from typing import List, Mapping, Optional, Sequence, Tuple from absl.testing import parameterized import numpy as np @@ -53,10 +54,8 @@ def _create_matmul_model( input_shape: Sequence[int], weight_shape: Sequence[int], saved_model_path: str, - has_bias: bool = False, + bias_fn: Optional[ops.Operation] = None, activation_fn: Optional[ops.Operation] = None, - bias_size: Optional[int] = None, - use_biasadd: bool = True, ) -> module.Module: class MatmulModel(module.Module): """A simple model with a single matmul. @@ -67,40 +66,28 @@ class MatmulModel(module.Module): def __init__( self, weight_shape: Sequence[int], - bias_size: Optional[int] = None, - activation_fn: Optional[ops.Operation] = None, - use_biasadd: bool = True, ) -> None: """Initializes a MatmulModel. Args: weight_shape: Shape of the weight tensor. - bias_size: If None, do not use bias. Else, use given size as bias. - activation_fn: The activation function to be used. No activation - function if None. - use_biasadd: If True, use BiasAdd for adding bias, else use AddV2. """ - self.bias_size = bias_size - self.activation_fn = activation_fn - self.use_biasadd = use_biasadd self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) - if bias_size is not None: - self.bias = np.random.uniform(low=-1.0, high=1.0, size=bias_size) - - def has_bias(self) -> bool: - return self.bias_size is not None + if bias_fn is not None: + self.bias = np.random.uniform( + low=-1.0, high=1.0, size=weight_shape[-1] + ) def has_reshape(self) -> bool: - return self.has_bias() and self.bias_size != self.filters.shape[-1] + return self.bias_fn() and self.bias_size != self.filters.shape[-1] @def_function.function def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a matrix multiplication. 
- Depending on self.has_bias and self.activation_fn, it may add a bias - term or - go through the activaction function. + Depending on self.bias_fn and self.activation_fn, it may add a bias + term or go through the activaction function. Args: input_tensor: Input tensor to matmul with the filter. @@ -109,18 +96,13 @@ def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: A map of: output key -> output result. """ out = math_ops.matmul(input_tensor, self.filters, name='sample/matmul') - + if bias_fn is not None: + out = bias_fn(out, self.bias) + if activation_fn is not None: + out = activation_fn(out) return {'output': out} - # If bias_size is not explictly given, it should default to width of weight. - if bias_size is None and has_bias: - bias_size = weight_shape[-1] - - # Verify that when bias_size is not None, has_bias should be True. - # And if bias_size is None, has_bias should be False. - assert (bias_size is None) != has_bias - - model = MatmulModel(weight_shape, bias_size, activation_fn) + model = MatmulModel(weight_shape) saved_model_save.save( model, saved_model_path, @@ -227,9 +209,9 @@ def _create_conv2d_model( input_shape: Sequence[int], filter_shape: Sequence[int], saved_model_path: str, - has_bias: bool = False, - has_batch_norm: bool = False, + bias_fn: Optional[ops.Operation] = None, activation_fn: Optional[ops.Operation] = None, + has_batch_norm: bool = False, strides: Sequence[int] = (1, 1, 1, 1), dilations: Sequence[int] = (1, 1, 1, 1), padding: str = 'SAME', @@ -277,6 +259,10 @@ def conv2d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: data_format='NHWC', name='sample/conv', ) + if bias_fn is not None: + out = nn_ops.bias_add(out, self.bias) + if activation_fn is not None: + out = activation_fn(out) if has_batch_norm: # Fusing is supported for non-training case. 
out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc index 6ef921c4260841..bfecd82e21e56f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc @@ -20,6 +20,7 @@ limitations under the License. #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil #include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep @@ -28,8 +29,9 @@ namespace py = pybind11; namespace { using ::mlir::quant::stablehlo::QuantizeStaticRangePtq; +using ::stablehlo::quantization::PopulateDefaults; -} +} // namespace PYBIND11_MODULE(pywrap_quantization, m) { // Supports absl::Status type conversions. @@ -55,17 +57,24 @@ PYBIND11_MODULE(pywrap_quantization, m) { as defined by the `MetaGraphDef::MetaInfoDef::function_aliases` from the input SavedModel. - `representative_dataset_file_map_serialized` is a signature key -> - `RepresentativeDatasetFile` (serialized) mapping for running the - calibration step. Each dataset file stores the representative dataset - for the function matching the signature key. - Raises `StatusNotOk` exception if when the run was unsuccessful. 
)pbdoc", py::arg("src_saved_model_path"), py::arg("dst_saved_model_path"), py::arg("quantization_config_serialized"), py::kw_only(), py::arg("signature_keys"), py::arg("signature_def_map_serialized"), - py::arg("function_aliases"), py::arg("py_function_library"), - py::arg("representative_dataset_file_map_serialized")); + py::arg("function_aliases"), py::arg("py_function_library")); + // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) + + // If the function signature changes, likely its corresponding .pyi type + // hinting should also change. + // LINT.IfChange(populate_default_configs) + m.def("populate_default_configs", &PopulateDefaults, + R"pbdoc( + Populates `QuantizationConfig` with default values. + + Returns an updated `QuantizationConfig` (serialized) after populating + default values to fields that the user did not explicitly specify. + )pbdoc", + py::arg("user_provided_config_serialized")); // LINT.ThenChange(pywrap_quantization.pyi:static_range_ptq) } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi index aee4463b238a6b..b3d016465004e6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.pyi @@ -27,8 +27,13 @@ def static_range_ptq( signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], py_function_library: py_function_lib.PyFunctionLibrary, - # Value type: RepresentativeDatasetFile. - representative_dataset_file_map_serialized: dict[str, bytes], ) -> Any: ... # Status # LINT.ThenChange() + +# LINT.IfChange(populate_default_configs) +def populate_default_configs( + user_provided_quantization_config_serialized: bytes, +) -> bytes: ... 
# QuantizationConfig + +# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py index 26e38f242af6d9..5e1ce4e7d65ba8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -17,7 +17,6 @@ from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as qc from tensorflow.compiler.mlir.quantization.stablehlo.python import pywrap_quantization -from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model from tensorflow.core.protobuf import meta_graph_pb2 @@ -45,6 +44,25 @@ def _serialize_signature_def_map( return signature_def_map_serialized +def _populate_default_quantization_config( + config: qc.QuantizationConfig, +) -> qc.QuantizationConfig: + """Populates `QuantizationConfig` with default values. + + Args: + config: User-provided quantization config. + + Returns: + Updated `QuantizationConfig` after populating default values to fields that + the user did not explicitly specify. + """ + pipeline_config = config.pipeline_config + if not pipeline_config.HasField('unpack_quantized_types'): + pipeline_config.unpack_quantized_types = True + + return config + + # TODO: b/310594193 - Export API to pip package. def quantize_saved_model( src_saved_model_path: str, @@ -71,6 +89,10 @@ def quantize_saved_model( ' single signature.' 
) + config = qc.QuantizationConfig.FromString( + pywrap_quantization.populate_default_configs(config.SerializeToString()) + ) + signature_def_map = save_model.get_signatures_from_saved_model( src_saved_model_path, signature_keys=None, @@ -82,18 +104,6 @@ def quantize_saved_model( config.tf_saved_model.tags ).meta_info_def.function_aliases - # Create a signature key -> `RepresentativeDatasetFile` mapping. - # `RepresentativeDatsetFile` should be serialized for `static_range_ptq` due - # to pywrap protobuf compatibility requirements. - tfrecord_file_path: str = ( - config.static_range_ptq_preset.representative_datasets[0].tf_record.path - ) - dataset_file_map = { - 'serving_default': quant_opts_pb2.RepresentativeDatasetFile( - tfrecord_file_path=tfrecord_file_path - ).SerializeToString() - } - signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) pywrap_quantization.static_range_ptq( src_saved_model_path, @@ -103,5 +113,4 @@ def quantize_saved_model( signature_def_map_serialized=signature_def_map_serialized, function_aliases=dict(function_aliases), py_function_library=py_function_lib.PyFunctionLibrary(), - representative_dataset_file_map_serialized=dataset_file_map, ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 89f396061de0b9..25623ad4497655 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -31,11 +31,18 @@ message RepresentativeDatasetConfig { // Minimal user input about representative datasets is required. Representative // datasets are required for static-range PTQ to retrieve quantization // statistics via calibration. -// Next ID: 2 +// Next ID: 3 message StaticRangePtqPreset { // Configures representative dataset. Each item corresponds to a // representative dataset used to calibrate a function. 
repeated RepresentativeDatasetConfig representative_datasets = 1; + + // NOTE: This field will be deprecated. + // Granularity should be controlled in custom configuration, deprecating + // this field once available. + // If set true, enable channel-wise quantization for all supported ops. + // This value is true by default. + bool enable_per_channel_quantized_weight = 2; } // Metadata specific to the input TensorFlow SavedModel, which may be required @@ -47,10 +54,77 @@ message TfSavedModelConfig { repeated string tags = 1; } +// Configures the graph transformation pipeline for quantization. +message PipelineConfig { + // When set to True, unpacks ops with uniform quantized types into operations + // without uniform quantized types (mostly i8 or i32). Useful when the target + // hardware performs better with integer ops. + // Default value: true + optional bool unpack_quantized_types = 1; +} + +// A quantization method representing "do not quantize". Mostly used for +// denylisting quantizable units from quantization. +message NoQuantization {} + +// Represents a matching method that matches quantizable units by lifted +// functions' names. +message FunctionNameMatcherSpec { + // Regular expression to match lifted functions' names. Underlying regex + // engine uses re2, which accepts a subset of PCRE. See + // https://github.com/google/re2/wiki/Syntax for details. + string regex = 1; +} + +// Matcher specification for identifying quantizable units. +message MatcherSpec { + // Matches lifted functions by their names. + FunctionNameMatcherSpec function_name = 1; +} + +// Specifies how to quantize matched quantizable units. +message Method { + NoQuantization no_quantization = 1; +} + +// A QuantizationSpec is essentially a (matcher spec, quantization method) pair, +// where the matcher spec is used to identify quantizable units and the +// quantization method specifies what type of quantization to apply on the +// matched quantizable units. 
+// Next ID: 3 +message QuantizationSpec { + // Configures matchers for identifying quantizable units. Matched quantizable + // units will be quantized according to `method`. + MatcherSpec matcher = 1; + + // Specifies how to quantize the matched quantizable units. + Method method = 2; +} + +// Quantization specifications. A simple wrapper around a sequence of +// `QuantizationSpec`s so that specs can be easily passed around or represented +// as a textproto. +// Next ID: 2 +message QuantizationSpecs { + // List of `QuantizationSpec`s. Later spec in the sequence takes precedence. + // + // NOTE: Tie-breaking mechanism is not yet supported. Providing multiple + // `QuantizationSpec` with conflicting quantizable units may result in + // undefined behavior. + // TODO: b/307620778 - Support tie-breaking for conflicting specs. + repeated QuantizationSpec specs = 1; +} + // Quantization configuration for StableHLO Quantizer. This is the primary // message containing all configurable options. -// Next ID: 4 +// Next ID: 5 message QuantizationConfig { + // Config presets provide predefined popular or common quantization specs. + // Lightweight users may choose one of the presets for quick experiments. Each + // preset is completely represented by `QuantizationSpecs`. When extra entries + // in `QuantizationSpecs` are provided along with a preset, then the preset + // will be overridden for the quantizable units matched by those additional + // `QuantizationSpec`s. oneof preset { // Performs best-effort static-range post-training quantization (PTQ). StaticRangePtqPreset static_range_ptq_preset = 1; @@ -58,4 +132,9 @@ message QuantizationConfig { // TF SavedModel specific information for the input model. TfSavedModelConfig tf_saved_model = 2; + + // Configures the graph transformation pipeline for quantization. 
+ PipelineConfig pipeline_config = 3; + + QuantizationSpecs specs = 4; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index 6fc15864fb0f8b..db4bc1a92483c1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -46,24 +46,3 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", ], ) - -tf_cc_test( - name = "stablehlo_op_quant_spec_test", - srcs = ["stablehlo_op_quant_spec_test.cc"], - deps = [ - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common:test_base", - "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:QuantOps", - "@stablehlo//:stablehlo_ops", - ], -) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 266b9735224e79..cba7d378fcc190 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -117,10 +117,10 @@ func.func @quantize_per_channel(%arg0: tensor<26x26x3x2xf32> // CHECK-DAG: %[[QMIN:.*]] = mhlo.constant dense<-2.14748365E+9> : tensor // CHECK-DAG: %[[QMAX:.*]] = mhlo.constant dense<2.14748365E+9> : tensor // CHECK: %[[DIVIDE:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : 
tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> // CHECK: %[[ADD:.*]] = chlo.broadcast_add %[[DIVIDE]], %[[ZPS]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> // CHECK: %[[CLAMP:.*]] = mhlo.clamp %[[QMIN]], %[[ADD]], %[[QMAX]] // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_even %[[CLAMP]] @@ -141,12 +141,12 @@ func.func @dequantize_per_channel( // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[-10, 2]> : tensor<2xi32> // CHECK: %[[SUBTRACT:.*]] = chlo.broadcast_subtract // CHECK-SAME: %[[INPUT:.*]], %[[ZPS]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xi32>, tensor<2xi32>) -> tensor<26x26x3x2xi32> // CHECK: %[[FLOAT:.*]] = mhlo.convert %[[SUBTRACT]] // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply // CHECK-SAME: %[[FLOAT]], %[[SCALES]] - // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-SAME: {broadcast_dimensions = array} // CHECK-SAME: (tensor<26x26x3x2xf32>, tensor<2xf32>) -> tensor<26x26x3x2xf32> %0 = mhlo.uniform_dequantize %arg0 : ( tensor<26x26x3x2x!quant.uniform> @@ -304,6 +304,78 @@ func.func @add_different_res_type( // ----- +// CHECK-LABEL: func @add_per_channel +func.func @add_per_channel( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: %[[ZPS:.*]] = mhlo.constant dense<[3, 2]> : tensor<2xi32> + // CHECK: %[[BCAST_SUB:.*]] = chlo.broadcast_subtract %[[ADD]], %[[ZPS]] + // CHECK-SAME: {broadcast_dimensions = array} + // CHECK-SAME: (tensor, tensor<2xi32>) -> tensor + // CHECK: return %[[BCAST_SUB]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +// CHECK-LABEL: func 
@add_per_channel_no_zp +func.func @add_per_channel_no_zp( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // CHECK: %[[ADD:.*]] = mhlo.add {{.*}} : tensor + // CHECK: return %[[ADD]] : tensor + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_i8( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_different_quant_types( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + +func.func @add_per_channel_per_tensor_mix( + %arg0: tensor>, + %arg1: tensor> + ) -> tensor> { + // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}} + // expected-error@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} + %11 = mhlo.add %arg0, %arg1 : ( + tensor>, + tensor> + ) -> tensor> + return %11 : tensor> +} + +// ----- + // CHECK-LABEL: func @requantize func.func @requantize( %arg0: tensor> @@ -351,10 +423,10 @@ func.func @requantize_per_channel( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = 
array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] @@ -375,10 +447,10 @@ func.func @requantize_per_channel_to_per_tensor( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] @@ -399,10 +471,10 @@ func.func @requantize_per_tensor_to_per_channel( // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant 
dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] - // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: broadcast_dimensions = array // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir index 7a568425415170..69f3f50b64ded6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert_tf_quant_ops_to_mhlo.mlir @@ -40,7 +40,7 @@ func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> tensor<3x2xf32> { quantization_axis = -1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 } : (tensor<3x2xf32>, tensor, tensor) -> tensor<3x2x!tf_type.qint32> - // CHECK: chlo.broadcast_add %[[LHS2]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: chlo.broadcast_add %[[LHS2]], %[[RHS]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %1 = "tf.UniformQuantizedAdd"( @@ -85,7 +85,7 @@ func.func @uniform_quantized_add_bias_not_const(%input1: tensor<3x2xi32>, %input %input1_qint = "tf.Cast"(%input1) {Truncate = false} : (tensor<3x2xi32>) -> tensor<3x2x!tf_type.qint32> %input2_qint = "tf.Cast"(%input2) {Truncate = false} : (tensor<2xi32>) -> tensor<2x!tf_type.qint32> - // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS_2]], %[[RHS_2]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[RES:.*]] = chlo.broadcast_add 
%[[LHS_2]], %[[RHS_2]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %result = "tf.UniformQuantizedAdd"( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir similarity index 61% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/post_calibration_component.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir index e97e5b1d11cfa3..aa11f88937e913 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/post_calibration_component.mlir @@ -1,4 +1,7 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-test-post-calibration-component | FileCheck %s +// RUN: stablehlo-quant-opt %s -stablehlo-test-post-calibration-component \ +// RUN: -split-input-file | FileCheck %s +// RUN: stablehlo-quant-opt %s -stablehlo-test-post-calibration-component='unpack-quantized-types=false' \ +// RUN: -split-input-file | FileCheck %s --check-prefix=CHECK-NO-UNPACK // Tests that a simple dot_general (lifted as a function) with CustomAggregators // around it is quantized. The resulting graph has quantized types unpacked into @@ -31,6 +34,33 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: // ----- +// Tests that a simple dot_general (lifted as a function) with CustomAggregators +// around it is quantized, when the 'unpack-quantized-types' option is set to +// false. This test case inputs the same graph as the test above. Note that now +// the uniform quantized types are directly expressed within the graph. 
+ +func.func @main_no_unpack(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { + %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> + %1 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + %2 = "tf.XlaCallModule"(%1, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> +} +func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} +// CHECK-NO-UNPACK-LABEL: func.func @main_no_unpack +// CHECK-NO-UNPACK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32> +// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, {{.*}}>> +// CHECK-NO-UNPACK: %[[QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x1024xf32>) -> tensor<1x1024x!quant.uniform> +// CHECK-NO-UNPACK: 
%[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[QUANTIZE_0]], %[[CONST]], contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-NO-UNPACK: %[[QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-NO-UNPACK: %[[DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[QUANTIZE_1]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK-NO-UNPACK: return %[[DEQUANTIZE]] : tensor<1x3xf32> + +// ----- + // Tests that a simple dot_general without CustomAggregators is not quantized. func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/pre_calibration_component.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir new file mode 100644 index 00000000000000..09afb528f602aa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/tf_to_stablehlo.mlir @@ -0,0 +1,59 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics -stablehlo-test-tf-to-stablehlo | FileCheck %s + +func.func @fused_batchnorm_no_training(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) { + %cst_0 = "tf.Const"() {value = dense<[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2]> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {value = dense<[0.3, 0.4, 0.3, 0.4, 0.3, 0.4, 0.3, 0.4]> : tensor<8xf32>} : () -> tensor<8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, 
%cst_0, %cst_1, %cst_0, %cst_1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} +// CHECK: func.func @main(%[[ARG_0:.+]]: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<{{.*}}> : tensor<8xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<{{.*}}> : tensor<8xf32> +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[MUL:.*]] = stablehlo.multiply %arg0, %[[BROADCAST_0]] : tensor<8x8x8x8xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[MUL]], %[[BROADCAST_1]] : tensor<8x8x8x8xf32> +// CHECK: return %[[ADD]] : tensor<8x8x8x8xf32> + +// ----- + +func.func @fuse_conv_batchnorm(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst_0) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1:6 = "tf.FusedBatchNormV3"(%0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, 
tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + func.return %1#0 : tensor<1x3x2x2xf32> +} +// CHECK: func.func @main(%[[ARG:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<2xf32>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[ARG]], %[[BROADCAST_1]]) {{.*}} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_0]] : tensor<1x3x2x2xf32> +// CHECK: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @func_conv_batchnorm_relu6(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_1 = "tf.Const"() {value = dense<[0.1, 0.2]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {value = dense<[0.3, 0.4]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst_0) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1:6 = "tf.FusedBatchNormV3"(%0, %cst_1, %cst_2, %cst_1, %cst_2) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + %2 = "tf.Relu6"(%1#0) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: 
func.func @main(%[[ARG:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<[{{.*}}]> : tensor<2xf32> +// CHECK-DAG: %[[CONST_2:.*]] = stablehlo.constant dense<6.000000e+00> : tensor +// CHECK-DAG: %[[CONST_3:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [3] : (tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %[[CONST_1]], dims = [3] : (tensor<2xf32>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[ARG]], %[[BROADCAST_1]]) {{.*}} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_0]] : tensor<1x3x2x2xf32> +// CHECK: %[[RELU6:.*]] = stablehlo.clamp %[[CONST_3]], %[[ADD]], %[[CONST_2]] : (tensor, tensor<1x3x2x2xf32>, tensor) -> tensor<1x3x2x2xf32> +// CHECK: return %[[RELU6]] : tensor<1x3x2x2xf32> + diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir deleted file mode 100644 index 6923209e531da9..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_func_to_bfloat16.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> -func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - // CHECK-NOT: f32 - // CHECK: stablehlo.add - %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - return %0 : tensor<3x3xf32> -} - -// ----- - -// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, 
%arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> -func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK-NOT: f64 - // CHECK: stablehlo.add - %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> - return %0 : tensor<3x3xf64> -} - -// ----- - -// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> -func.func @constant_f32() -> tensor<2x2xf32> { - // CHECK-NOT: f32 - // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> - %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - -func.func @constant_elided() -> tensor<2x2xf32> { - // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} - %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> - return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> -func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { - // CHECK-NOT: f32 - // CHECK: stablehlo.reduce_window - %0 = stablehlo.constant dense<0.0> : tensor - %1 = "stablehlo.reduce_window"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %2 = stablehlo.maximum %arg1, %arg2 : tensor - stablehlo.return %2 : tensor - }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> - return %1 : tensor<2x3x1x3xf32> -} - -// ----- - diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir new file mode 100644 index 00000000000000..fdb5860eb1bd23 --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_func_to_bfloat16.mlir @@ -0,0 +1,128 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK-NOT: f64 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> + return %0 : tensor<3x3xf64> +} + +// ----- + +// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> +func.func @constant_f32() -> tensor<2x2xf32> { + // CHECK-NOT: f32 + // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @constant_elided() -> tensor<2x2xf32> { + // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} + %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> +func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.reduce_window + %0 = stablehlo.constant dense<0.0> : tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = 
stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %2 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + return %1 : tensor<2x3x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_i32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xi32> +func.func @bitcast_convert_f32_i32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xi32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xi32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + return %20 : tensor<1x256128xi32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert 
%arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xui32> +func.func @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xui32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_f32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_f32_f32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xf32> { + // Convert bitcast_convert to no-op for f32->f32. + // CHECK: return %arg0 : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> +func.func @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> { + // Do not convert bitcast_convert for legal types. 
+ // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + // CHECK: return %[[BITCAST]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + return %20 : tensor<1x256128xbf16> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_xla_call_module_op_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_xla_call_module_op_to_bfloat16.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/convert_xla_call_module_op_to_bfloat16.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/convert_xla_call_module_op_to_bfloat16.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir similarity index 85% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir index 052d618a2e147c..e743a19dc0a822 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/lift_quantizable_spots_as_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions.mlir @@ -38,14 +38,38 @@ func.func @dot_general_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // ----- +// CHECK-LABEL: @dot_general_with_bias_same_shape_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2xf32> +func.func @dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, 
DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x3xf32> + func.return %3: tensor<1x3xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: return %[[ADD]] : tensor<1x3xf32> +// CHECK: } + +// ----- + // CHECK-LABEL: @conv_with_bias_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> - %3 = stablehlo.add %2, %1 : tensor<1x3x3x4xf32> - func.return %3: tensor<1x3x3x4xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.add %2, %3 : tensor<1x3x3x4xf32> + func.return %4: tensor<1x3x3x4xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> @@ -54,8 +78,9 @@ func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> // CHECK: } // CHECK-LABEL: private @composite_conv_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: 
%[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) -// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] // CHECK: return %[[ADD]] : tensor<1x3x3x4xf32> // CHECK: } @@ -65,10 +90,11 @@ func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> - %3 = stablehlo.add %2, %1 : tensor<1x1x64xf32> - func.return %3: tensor<1x1x64xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %2, %3 : tensor<1x1x64xf32> + func.return %4: tensor<1x1x64xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> @@ -77,8 +103,9 @@ func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64 // CHECK: } // CHECK-LABEL: private @composite_dot_general_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] // CHECK: return %[[ADD]] : tensor<1x1x64xf32> // CHECK: } @@ -361,16 +388,44 @@ func.func @dot_general_with_relu6_dynamic_fn(%arg0: tensor) -> tens // ----- +// CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func 
@dot_general_with_bias_same_shape_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %3, %1 : tensor<1x1x64xf32> + %5 = stablehlo.maximum %4, %2 : tensor<1x1x64xf32> + func.return %5: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu_fn_1 +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + // CHECK-LABEL: @conv_with_bias_and_relu_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, 
tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> - %4 = stablehlo.add %3, %1 : tensor<1x3x3x4xf32> - %5 = stablehlo.maximum %4, %2 : tensor<1x3x3x4xf32> - func.return %5: tensor<1x3x3x4xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %5 = stablehlo.add %3, %4 : tensor<1x3x3x4xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x3x3x4xf32> + func.return %6: tensor<1x3x3x4xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> @@ -380,8 +435,9 @@ func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x // CHECK-LABEL: private @composite_conv_with_bias_and_relu_fn_1 // CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) -// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] // CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> // CHECK: } @@ -392,12 +448,13 @@ func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> - %4 = stablehlo.add %3, %1 : tensor<1x1x64xf32> - %5 = stablehlo.maximum %4, %2 : tensor<1x1x64xf32> - func.return %5: tensor<1x1x64xf32> + %4 = 
stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %3, %4 : tensor<1x1x64xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> @@ -407,8 +464,9 @@ func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tens // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_fn_1 // CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] // CHECK: return %[[MAX]] : tensor<1x1x64xf32> // CHECK: } @@ -485,17 +543,47 @@ func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor) // ----- +// CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_same_shape_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %4, %1 : tensor<1x1x64xf32> + %6 = stablehlo.clamp %2, %5, %3 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = 
stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu6_fn_1 +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + // CHECK-LABEL: @conv_with_bias_and_relu6_fn( // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32> %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> - %5 = stablehlo.add %4, %1 : tensor<1x3x3x4xf32> - %6 = stablehlo.clamp %2, %5, %3 : tensor<1x3x3x4xf32> - func.return %6: tensor<1x3x3x4xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %6 = stablehlo.add %4, %5 : tensor<1x3x3x4xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x3x3x4xf32> + func.return %7: tensor<1x3x3x4xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant 
dense<2.000000e+00> @@ -506,8 +594,9 @@ func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3 // CHECK-LABEL: private @composite_conv_with_bias_and_relu6_fn_1 // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) -// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> // CHECK: } @@ -518,13 +607,14 @@ func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3 // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> - %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> - %5 = stablehlo.add %4, %1 : tensor<1x1x64xf32> - %6 = stablehlo.clamp %2, %5, %3 : tensor<1x1x64xf32> - func.return %6: tensor<1x1x64xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %6 = stablehlo.add %4, %5 : tensor<1x1x64xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x1x64xf32> + func.return %7: tensor<1x1x64xf32> } // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> @@ -535,8 +625,9 @@ func.func 
@dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> ten // CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_fn_1 // CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> // CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 // CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] // CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> // CHECK: } @@ -619,7 +710,7 @@ func.func @gather_fn() -> tensor<2x3x2x2xi32> { collapsed_slice_dims = [0], start_index_map = [1, 0], index_vector_dim = 2>, - slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + slice_sizes = array, indices_are_sorted = false } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> func.return %2: tensor<2x3x2x2xi32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir new file mode 100644 index 00000000000000..00b3dd3b5e57a4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir @@ -0,0 +1,25 @@ +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs \ +// RUN: -split-input-file | FileCheck %s + +// CHECK: @main +func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> 
tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} +// Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp +// is missing attributes required for quantization. + +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK-SAME: {_entry_function = @composite_dot_general_fn_1, {{.*}}} +// CHECK-NOT: _original_entry_function +// CHECK-NOT: _tfl_quant_trait +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_fn_1 +// CHECK-NOT: tf_quant.composite_function +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/optimize_graph.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/optimize_graph.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/optimize_graph.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/optimize_graph.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/post_quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/post_quantize.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir similarity index 91% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir index 67a059c6061ceb..7634d81e976ca1 100644 --- 
a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantization=false -verify-diagnostics | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantized-weight=false -verify-diagnostics | FileCheck %s // ----- @@ -74,17 +74,17 @@ func.func @dot_redundant_stats(%arg0: tensor) -> tensor { // ----- -// CHECK-LABEL: func @convert_same_scale_propagate -func.func @convert_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { +// CHECK-LABEL: func @reshape_same_scale_propagate +func.func @reshape_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { // CHECK: %[[dq:.*]] = "quantfork.dcast" // CHECK-SAME: (tensor<2x3x!quant.uniform>) %0 = "quantfork.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK: %[[convert:.*]] = stablehlo.convert %[[dq]] - %1 = stablehlo.convert %0 : (tensor<2x3xf32>) -> (tensor<2x3xf32>) - // CHECK: %[[q:.*]] = "quantfork.qcast"(%[[convert]]) - // CHECK-SAME: -> tensor<2x3x!quant.uniform> - %2 = "quantfork.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %2 : tensor<2x3xf32> + // CHECK: %[[reshape:.*]] = stablehlo.reshape %[[dq]] + %1 = stablehlo.reshape %0 : (tensor<2x3xf32>) -> (tensor<6xf32>) + // CHECK: %[[q:.*]] = "quantfork.qcast"(%[[reshape]]) + // CHECK-SAME: -> tensor<6x!quant.uniform> + %2 = "quantfork.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<6xf32>) -> tensor<6xf32> + func.return %2 : tensor<6xf32> } // ----- @@ -135,6 +135,6 @@ 
func.func @skip_nan_inf_constant(%arg0: tensor) -> tensor, %arg2: tensor): %7 = stablehlo.maximum %arg1, %arg2 : tensor stablehlo.return %7 : tensor - }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor, tensor) -> tensor + }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor return %6 : tensor } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_int4.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_int4.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_int4.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir similarity index 98% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir index f509ffce05863d..a6159c1dd62b4b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize_per_channel.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/prepare_quantize_per_channel.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantization=true -verify-diagnostics | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-prepare-quantize=enable-per-channel-quantized-weight=true -verify-diagnostics | 
FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir similarity index 95% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_op_with_region.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir index 747e5ee4188757..04104c308a3b3d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_op_with_region.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_op_with_region.mlir @@ -18,7 +18,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - // CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64> + // CHECK-SAME: window_dimensions = array // CHECK-SAME: (tensor<2x3x1x3x!quant.uniform>, tensor>) -> tensor<2x3x1x3x!quant.uniform> // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]]) @@ -39,7 +39,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p ^bb0(%arg1: tensor, %arg2: tensor): %14 = stablehlo.maximum %arg1, %arg2 : tensor stablehlo.return %14 : tensor - }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>} : (tensor<2x3x1x3xf32>, tensor) -> 
tensor<2x3x1x3xf32> + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> %12 = "quantfork.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> %13 = "quantfork.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> return %13 : tensor<2x3x1x3xf32> @@ -74,7 +74,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> - // CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64> + // CHECK-SAME: window_dimensions = array // CHECK-SAME: (tensor<2x3x1x1024x!quant.uniform>, tensor>) -> tensor<2x3x1x1024x!quant.uniform> // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[CST1]]) @@ -93,7 +93,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p ^bb0(%arg1: tensor, %arg2: tensor): %14 = stablehlo.maximum %arg1, %arg2 : tensor stablehlo.return %14 : tensor - }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>} : (tensor<2x3x1x1024xf32>, tensor) -> tensor<2x3x1x1024xf32> + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x1024xf32>, tensor) -> tensor<2x3x1x1024xf32> %7 = "quantfork.qcast"(%6) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> %8 = "quantfork.dcast"(%7) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> %9 = "quantfork.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> @@ -136,7 +136,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], 
%[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> - // CHECK-SAME: window_dimensions = dense<[1, 3, 1]> : tensor<3xi64> + // CHECK-SAME: window_dimensions = array // CHECK-SAME: (tensor<2x3x3x!quant.uniform>, tensor>) -> tensor<2x3x3x!quant.uniform> // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]]) @@ -160,7 +160,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p ^bb0(%arg1: tensor, %arg2: tensor): %17 = stablehlo.maximum %arg1, %arg2 : tensor stablehlo.return %17 : tensor - }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = dense<[1, 3, 1]> : tensor<3xi64>} : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> %15 = "quantfork.qcast"(%14) {volatile} : (tensor<2x3x3xf32>) -> tensor<2x3x3x!quant.uniform> %16 = "quantfork.dcast"(%15) : (tensor<2x3x3x!quant.uniform>) -> tensor<2x3x3xf32> return %16 : tensor<2x3x3xf32> @@ -195,7 +195,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> // CHECK: stablehlo.return %[[MAX]] : tensor> // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> - // CHECK-SAME: window_dimensions = dense<[1, 3, 1]> : tensor<3xi64> + // CHECK-SAME: window_dimensions = array // CHECK-SAME: (tensor<2x3x1024x!quant.uniform>, tensor>) -> tensor<2x3x1024x!quant.uniform> // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[REDUCE]] @@ -215,7 +215,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p ^bb0(%arg1: tensor, %arg2: tensor): %17 = stablehlo.maximum %arg1, %arg2 : tensor stablehlo.return %17 : tensor - }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = 
dense<[1, 3, 1]> : tensor<3xi64>} : (tensor<2x3x1024xf32>, tensor) -> tensor<2x3x1024xf32> + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x1024xf32>, tensor) -> tensor<2x3x1024xf32> %7 = "quantfork.qcast"(%6) {volatile} : (tensor<2x3x1024xf32>) -> tensor<2x3x1024x!quant.uniform> %8 = "quantfork.dcast"(%7) : (tensor<2x3x1024x!quant.uniform>) -> tensor<2x3x1024xf32> %9 = stablehlo.reshape %8 : (tensor<2x3x1024xf32>) -> tensor<2x3x1x1024xf32> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir similarity index 90% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir index d149ab12324f52..f437016ed2c6f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_same_scale.mlir @@ -154,46 +154,6 @@ module attributes {tf_saved_model.semantics} { // ----- -module attributes {tf_saved_model.semantics} { - // CHECK-LABEL: composite_and_convert - // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> - // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> - func.func private @composite_and_convert(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { - // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) - // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<1x3x!quant.uniform> - // CHECK: %[[CONVERT:.*]] = 
stablehlo.convert %[[CALL]] : tensor<1x3x!quant.uniform> - // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CONVERT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - // CHECK: return %[[DQ]] - %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %1 = "quantfork.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> - %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - %3 = "quantfork.dcast"(%2) : (tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<2x3xf32> - %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %5 = "quantfork.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> - %6 = "quantfork.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - %7 = stablehlo.convert %6 : (tensor<1x3xf32>) -> tensor<1x3xf32> - %8 = "quantfork.qcast"(%7) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> - %9 = "quantfork.dcast"(%8) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> - return %9 : tensor<1x3xf32> - } - - // CHECK: quantized_dot_general_fn_1 - // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> - // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>> - func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] - // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, 6.000000e-03:13>>) -> tensor<1x3x!quant.uniform> - // CHECK: %[[Q3:.*]] = 
stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> - // CHECK: return %[[Q3]] - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -} - -// ----- - module attributes {tf_saved_model.semantics} { // CHECK-LABEL: composite_and_pad // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> @@ -351,7 +311,7 @@ module attributes {tf_saved_model.semantics} { collapsed_slice_dims = [0], start_index_map = [1, 0], index_vector_dim = 2>, - slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + slice_sizes = array, indices_are_sorted = false } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> %8 = "quantfork.qcast"(%7) {volatile} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2x!quant.uniform> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir new file mode 100644 index 00000000000000..ff7bfafe654099 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -0,0 +1,567 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions=enable-per-channel-quantized-weight=false | FileCheck --check-prefix=CHECK-PER-TENSOR %s + +// Tests that basic dot_general is properly quantized. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Checks that the entry function is quantized for dot_general. Quantized +// dot_general outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. 
+ +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for dot_general + bias is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// 
CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> 
tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +} + +// ----- + +// Tests that fused pattern for dot_general + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = 
dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = 
stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor> + + +// ----- + +// Tests that basic convolution is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, 
i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias is properly quantized. + +// Checks that fused functions with 1D bias is properly quantized. 
+// The 1D bias should be broadcasted in dims [3], where it initially has +// `quantizedDimension=0`, but has `quantizedDimension=3` after broadcasting. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_1d_fn, _original_entry_function = "composite_conv_with_bias_1d_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<47978> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call 
@quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], 
[1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %[[ARG_3]] +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] 
: (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Checks that fused functions with 4D bias is properly quantized. +// The 4D bias should be braoadcasted in dims [0, 1, 2, 3], where it +// already has `quantizedDimension=3`. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_fn, _original_entry_function = "composite_conv_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> 
tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, 
dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2 +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// 
CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_dynamic_fn, _original_entry_function = "composite_conv_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : 
tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 
1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : 
(tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.maximum. 
+// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : 
(tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers 
= [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) 
-> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu6 
with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.clamp. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = 
stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> 
: tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> 
tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = 
stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is +// not quantized. 
+ +// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[XLA_CALL_MODULE_0]] + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} +// Check that the composite_dot_general_fn is untouched. +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]] +// CHECK: return %[[DOT_GENERAL_0]] +} + +// ----- + +// Tests that basic gather is properly quantized. 
+ +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_gather_fn(%[[ARG:.+]]: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_gather_fn(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+// CHECK: %[[CONST:.+]] = stablehlo.constant dense<{{.*}}> : tensor<2x3x2xi32> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<2x3x2x2x!quant.uniform) -> tensor<2x3x2x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE]] : tensor<2x3x2x2xf32> + +// CHECK: func.func private @quantized_gather_fn(%[[ARG_0:.+]]: tensor<3x4x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +// CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<2x3x2x2x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir similarity index 100% rename from 
tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/restore_function_name.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/restore_function_name.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/restore_function_name.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/restore_function_name.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/unfuse_mhlo_batch_norm.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unfuse_mhlo_batch_norm.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unwrap_xla_call_module_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/tests/unwrap_xla_call_module_op.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/unwrap_xla_call_module_op.mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir deleted file mode 100644 index 779ef786714fb7..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/populate_shape.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -populate-shape | FileCheck %s - -// CHECK-LABEL: @populate_shape_for_custom_aggregator -func.func 
@populate_shape_for_custom_aggregator(%input: tensor) { - // CHECK: %[[OUTPUT:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> - func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_xla_call_module -func.func @populate_shape_for_xla_call_module(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // CHECK: %[[OUTPUT:.*]] = "tf.XlaCallModule"(%[[INPUT:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor - %0 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor<*xf32> - func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_chain_of_ops -func.func @populate_shape_for_chain_of_ops(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // CHECK: 
%[[VAL_0:.*]] = "tf.CustomAggregator"(%[[INPUT:.*]]) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - // CHECK: %[[VAL_1:.*]] = "tf.XlaCallModule"(%[[VAL_0:.*]], %[[CST:.*]]) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> tensor - // CHECK: %[[VAL_2:.*]] = "tf.CustomAggregator"(%[[VAL_1:.*]]) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor - %0 = "tf.CustomAggregator"(%input) <{id = "49d53b0"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor) -> tensor<*xf32> - %1 = "tf.XlaCallModule"(%0, %cst) <{Sout = [#tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<*xf32>, tensor<1x1x64x256xf32>) -> tensor<*xf32> - %2 = "tf.CustomAggregator"(%1) <{id = "49d53b1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 6.000000e+00 : f32, max_percentile = 0.000000e+00 : f32, min = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - 
func.return -} - -// ---- - -// CHECK-LABEL: @populate_shape_for_xla_call_module_failure_not_single_output -func.func @populate_shape_for_xla_call_module_failure_not_single_output(%input: tensor) { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<1x1x64x256xf32>} : () -> tensor<1x1x64x256xf32> - // expected-error @+2 {{XlaCallModuleOp doesn't have 1 output.}} - %0, %1 = "tf.XlaCallModule"(%input, %cst) <{Sout = [#tf_type.shape, #tf_type.shape], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> (tensor<*xf32>, tensor<*xf32>) - // expected-error @+1 {{XlaCallModuleOp doesn't have 1 output.}} - "tf.XlaCallModule"(%input, %cst) <{Sout = [], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @main_9, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor, tensor<1x1x64x256xf32>) -> () - func.return -} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir deleted file mode 100644 index d02b3ada1aa60e..00000000000000 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_composite_functions.mlir +++ /dev/null @@ -1,277 +0,0 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ -// RUN: -stablehlo-quantize-composite-functions | FileCheck %s - - -// Tests that basic dot_general is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
-// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } -// Checks that the quantized XlaCallModule has been replaced by a CallOp, which -// calls the quantized entry function. 
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - -// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } -// Checks that the entry function is quantized for dot_general. Quantized -// dot_general outputs an i32 quantized tensor, followed by requantization to -// i8 quantized tensor. -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for dot_general + bias is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. 
-// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_with_bias_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_with_bias_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, _original_entry_function = "composite_dot_general_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - return %2 : tensor<1x3xf32> - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3xi32>} : () -> tensor<1x3x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : 
(tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> - -// CHECK: func.func private @quantized_dot_general_with_bias_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_dot_general_with_bias_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> - -} - -// ----- - -// Tests that fused pattern for dot_general + bias with dynamic batch dimension -// is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. 
-// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<3xi32>} : () -> tensor<3x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : 
(tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} - func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - return %5 : tensor - } -} -// CHECK: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = 
stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor> - - -// ----- - -// Tests that basic convolution is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64, _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> - return %2 : tensor<1x3x4x2xf32> - } -// Check that the quantized XlaCallModule has been replaced by a CallOp, which -// calls the quantized entry function. 
-// CHECK: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> - -// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - return %0 : tensor<1x3x4x2xf32> - } -// Checks that the entry function is quantized for convolution. Quantized -// convolution outputs an i32 quantized tensor, followed by requantization to -// i8 quantized tensor. 
-// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for convolution + bias is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3x4x2xf32>} : () -> tensor<1x3x4x2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_fn, _original_entry_function = "composite_conv_with_bias_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> - %2 = "quantfork.stats"(%1) {layerStats = 
dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> - return %2 : tensor<1x3x4x2xf32> - } -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> - -// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} - func.func private @composite_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> - %1 = stablehlo.add %0, %arg2 : tensor<1x3x4x2xf32> - return %1 : tensor<1x3x4x2xf32> - } -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, 
tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[ARG_3]] : tensor<1x3x4x2x!quant.uniform> -// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> -// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> -} - -// ----- - -// Tests that fused pattern for convolution + bias with dynamic batch dimension -// is properly quantized. - -// The following pattern does not converge because of a bug in QuantizePass. -// TODO - b/305469508: Fix the QuantizePass to avoid this warning. -// expected-warning @+1 {{Failed to converge pattern at QuantizePass.}} -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} - func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3x4x2>], _entry_function = @composite_conv_with_bias_dynamic_fn, _original_entry_function = "composite_conv_with_bias_dynamic_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor - %2 = "quantfork.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor - return %2 : tensor - } 
- -// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> -// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<2xi32>} : () -> tensor<2x!quant.uniform -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> -// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor -// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor -// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor - -// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} -func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module} { - %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> - %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> - %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> - %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor - %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor - %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> - %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> - %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [3] : (tensor<2xf32>, tensor<4xi32>) -> tensor - %5 = stablehlo.add %0, %4 : tensor - return %5 : tensor - } -} 
-// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> -// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> -// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> -// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> -// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) -// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> -// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [3] : (tensor<2x!quant.uniform>, tensor<4xi32>) -> tensor> -// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> -// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> - -// ----- - -// Tests that XlaCallModule op is not quantized without the quantfork.stats ops. 
- -module attributes {tf_saved_model.semantics} { -// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} - func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { - %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %1 : tensor<1x3xf32> - } - -// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is -// not quantized. -// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> -// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> -// CHECK: return %[[XLA_CALL_MODULE_0]] - -// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} - func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { - %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> - } - -// Check that the composite_dot_general_fn is untouched. 
-// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]] -// CHECK: return %[[DOT_GENERAL_0]] -} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 56161914b45dd4..1d2608599bd93b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -406,6 +406,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/quantization/common:lift_as_function_call", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", @@ -479,9 +480,11 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":passes", + "//tensorflow/compiler/mlir/lite/stablehlo:fuse_convolution_pass", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/lite/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 01d7a7b37e1907..62a6f27c8ad5f1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -99,6 +99,7 @@ cc_library( hdrs = ["convert_asset_args.h"], compatible_with = get_compatible_with_portable(), deps = [ + 
"//tensorflow/compiler/mlir/quantization/common:func", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/core/protobuf:for_core_protos_cc", @@ -139,9 +140,9 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -172,13 +173,16 @@ tf_cc_test( srcs = ["constant_fold_test.cc"], deps = [ ":constant_fold", + "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc index f2a7942323f5f5..5122563c235193 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc @@ -14,12 +14,19 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" +#include + +#include +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" @@ -30,38 +37,7 @@ namespace { using ::testing::NotNull; using ::testing::SizeIs; -class ConstantFoldingTest : public ::testing::Test { - protected: - ConstantFoldingTest() { - ctx_.loadDialect(); - } - - // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created - // module op is valid. - OwningOpRef ParseModuleOpString(absl::string_view module_op_str) { - auto module_op_ref = parseSourceString(module_op_str, &ctx_); - EXPECT_TRUE(module_op_ref); - return module_op_ref; - } - - // Gets the function with the given name from the module. - func::FuncOp GetFunctionFromModule(ModuleOp module, - absl::string_view function_name) { - SymbolTable symbol_table(module); - return symbol_table.lookup(function_name); - } - - // Returns the first operation with the given type in the function. 
- template - OpType FindOperationOfType(func::FuncOp function) { - for (auto op : function.getBody().getOps()) { - return op; - } - return nullptr; - } - - MLIRContext ctx_{}; -}; +class ConstantFoldingTest : public QuantizationTestBase {}; TEST_F(ConstantFoldingTest, FoldLargeConstant) { constexpr absl::string_view kModuleCode = R"mlir( @@ -80,8 +56,10 @@ TEST_F(ConstantFoldingTest, FoldLargeConstant) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* mul_op = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(mul_op); EXPECT_THAT(results, SizeIs(1)); @@ -106,8 +84,10 @@ TEST_F(ConstantFoldingTest, NotFoldingIdentity) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* op_to_fold = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(op_to_fold); EXPECT_THAT(results, SizeIs(1)); @@ -135,8 +115,10 @@ TEST_F(ConstantFoldingTest, NotFoldingArgument) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + Operation* op_to_fold = FindOperationOfType(test_func); SmallVector results = ConstantFoldOpIfPossible(op_to_fold); EXPECT_THAT(results, SizeIs(1)); @@ -166,11 +148,12 @@ TEST_F(ConstantFoldingTest, FoldDepthwiseConvWeight) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - 
GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); - RewritePatternSet patterns(&ctx_); - patterns.add(&ctx_); + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); EXPECT_TRUE( succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); @@ -198,11 +181,12 @@ TEST_F(ConstantFoldingTest, DepthwiseConvWeightNotFoldable) { )mlir"; OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); - func::FuncOp test_func = - GetFunctionFromModule(*module_op_ref, "test_fold_constant"); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); - RewritePatternSet patterns(&ctx_); - patterns.add(&ctx_); + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); EXPECT_TRUE( succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc index d88a1fe42cc555..8e2b3537b34d85 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -38,20 +39,6 @@ using ::mlir::tf_saved_model::LookupBoundInputOfType; using ::tensorflow::AssetFileDef; using ::tensorflow::kImportModelDefaultGraphFuncName; -// Gets the "main" function from the module. Returns an empty op iff it doesn't -// exist. -func::FuncOp GetMainFunction(ModuleOp module_op) { - const auto main_func_id = - StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); - auto func_ops = module_op.getOps(); - auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { - return func_op.getName() == main_func_id; - }); - - if (main_func_itr == func_ops.end()) return {}; - return *main_func_itr; -} - // Given argument attributes `arg_attrs`, returns a new set of argument // attributes where the "tf_saved_model.bound_input" attribute has been replaced // with the "tf_saved_model.index_path" attribute. 
`index_path` is the element @@ -130,7 +117,7 @@ void ConvertMainArgAttrs(func::FuncOp main_func_op, const int arg_idx, } // namespace FailureOr> ConvertAssetArgs(ModuleOp module_op) { - func::FuncOp main_func_op = GetMainFunction(module_op); + func::FuncOp main_func_op = FindMainFuncOp(module_op); if (!main_func_op) return failure(); SmallVector input_names = GetEntryFunctionInputs(main_func_op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc index ef63e75e52e8c9..ee7ae1a4c6d90a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace quantization { @@ -33,11 +34,8 @@ absl::Status RunPassesOnModuleOp( absl::StatusOr> dump_file; if (mlir_dump_file_name) { - dump_file = tensorflow::quantization::MaybeEnableIrPrinting( - pass_manager, mlir_dump_file_name.value()); - if (!dump_file.ok()) { - return dump_file.status(); - } + TF_RETURN_IF_ERROR(tensorflow::quantization::MaybeEnableIrPrinting( + pass_manager, mlir_dump_file_name.value())); } if (failed(pass_manager.run(module_op))) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h index 35f243a9f9626d..06db2acb7b057f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -47,8 +48,7 @@ absl::Status RunPasses(const absl::string_view name, FuncT add_passes_func, add_passes_func(pm); mlir::StatusScopedDiagnosticHandler diagnostic_handler{&ctx}; - TF_ASSIGN_OR_RETURN(const std::unique_ptr out_dump_file, - MaybeEnableIrPrinting(pm, name)); + TF_RETURN_IF_ERROR(MaybeEnableIrPrinting(pm, name)); if (failed(pm.run(module_op))) { return absl::InternalError( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index fa55ce0ba1e391..b465fe15e8d57c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -19,16 +19,21 @@ cc_library( hdrs = ["mlir_dump.h"], compatible_with = get_compatible_with_portable(), deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -39,14 +44,19 @@ tf_cc_test( deps = [ ":mlir_dump", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", 
"@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", + "@stablehlo//:stablehlo_ops", ], ) @@ -57,7 +67,7 @@ tf_kernel_library( deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core:framework", - "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc index ce52e488774f60..312848e5125af3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/dump_tensor_op.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/io/compression.h" +#include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/env.h" @@ -72,7 +74,19 @@ class DumpTensorOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("node_name", &node_name)); OP_REQUIRES_OK(ctx, ctx->env()->RecursivelyCreateDir(log_dir_path)); - tensor_data_path_ = io::JoinPath(log_dir_path, file_name); + std::string tensor_data_path = io::JoinPath(log_dir_path, file_name); + std::unique_ptr tensor_data_file; + OP_REQUIRES_OK( + ctx, ctx->env()->NewWritableFile(tensor_data_path, &tensor_data_file)); + + // Turn on Zlib compression. 
+ io::RecordWriterOptions options = + io::RecordWriterOptions::CreateRecordWriterOptions( + io::compression::kZlib); + tensor_data_writer_ = + std::make_unique(tensor_data_file.release(), options); + OP_REQUIRES(ctx, tensor_data_writer_ != nullptr, + absl::AbortedError("Could not create record writer")); // Fetch func_name and node_name from attributes and save as proto. quantization::UnitWiseQuantizationSpec::QuantizationUnit quant_unit_proto; @@ -80,28 +94,31 @@ class DumpTensorOp : public OpKernel { quant_unit_proto.set_node_name(node_name); string quant_unit_path = io::JoinPath(log_dir_path, "quant_unit.pb"); - OP_REQUIRES_OK( ctx, SaveSerializedProtoToFile(quant_unit_proto.SerializeAsString(), quant_unit_path, ctx->env())); } + ~DumpTensorOp() override { + (void)tensor_data_writer_->Flush(); + (void)tensor_data_writer_->Close(); + } + void Compute(OpKernelContext* ctx) override { - if (enabled_) { - const Tensor& tensor_data = ctx->input(0); + if (!enabled_) return; + + const Tensor& tensor_data = ctx->input(0); - TensorProto tensor_proto; - tensor_data.AsProtoTensorContent(&tensor_proto); + TensorProto tensor_proto; + tensor_data.AsProtoTensorContent(&tensor_proto); - OP_REQUIRES_OK(ctx, - SaveSerializedProtoToFile(tensor_proto.SerializeAsString(), - tensor_data_path_, ctx->env())); - } + OP_REQUIRES_OK(ctx, tensor_data_writer_->WriteRecord( + tensor_proto.SerializeAsString())); } private: - std::string tensor_data_path_; bool enabled_; + std::unique_ptr tensor_data_writer_; }; REGISTER_KERNEL_BUILDER(Name("DumpTensor").Device(DEVICE_CPU), DumpTensorOp); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc index 1d1a54e5a88db5..7a19c7b5579617 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -14,22 +14,34 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" +#include #include #include #include +#include +#include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/stringpiece.h" namespace tensorflow { namespace quantization { @@ -63,39 +75,138 @@ absl::StatusOr GetMlirDumpDir() { return dump_dir; } +// A simple wrapper of tsl::WritableFile so that mlir Pass infra can use it. 
+class WritableFileWrapper : public llvm::raw_ostream { + public: + ~WritableFileWrapper() override = default; + static absl::StatusOr> Create( + const std::string& filepath) { + std::unique_ptr file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(filepath, &file)); + return absl::WrapUnique(new WritableFileWrapper(std::move(file))); + } + + private: + explicit WritableFileWrapper(std::unique_ptr file) + : file_(std::move(file)) { + SetUnbuffered(); + } + + uint64_t current_pos() const override { + int64_t position; + if (file_->Tell(&position).ok()) { + return position; + } else { + return -1; + } + } + + void write_impl(const char* ptr, size_t size) override { + if (file_ && !file_->Append(tsl::StringPiece(ptr, size)).ok()) { + file_ = nullptr; + } + } + + std::unique_ptr file_; +}; + // Creates a new file to dump the intermediate MLIRs by prefixing the // `dump_file_name` with the value of the TF_QUANT_MLIR_DUMP_PREFIX env // variable. Returns absl::FailedPreconditionError if the env variable is not // set or set to an empty string. 
-absl::StatusOr> CreateMlirDumpFile( +absl::StatusOr> CreateMlirDumpFile( const absl::string_view dump_file_name) { const absl::StatusOr dump_dir = GetMlirDumpDir(); if (!dump_dir.ok()) { return dump_dir.status(); } - auto *env = tsl::Env::Default(); - const tsl::Status status = env->RecursivelyCreateDir(*dump_dir); - if (!status.ok()) { - return status; - } + auto* env = tsl::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(*dump_dir)); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream const std::string dump_file_path = tsl::io::JoinPath(*dump_dir, dump_file_name); - auto dump_file = std::make_unique(dump_file_path, ec); - if (ec) { - return absl::InternalError(absl::StrFormat( - "Unable to open file: %s, error: %s", dump_file_path, ec.message())); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr file, + WritableFileWrapper::Create(dump_file_path)); LOG(INFO) << "IR dump file created: " << dump_file_path; - return dump_file; + return file; } +class PrinterConfig : public mlir::PassManager::IRPrinterConfig { + public: + explicit PrinterConfig( + absl::string_view dump_file_prefix, bool print_module_scope = false, + bool print_after_only_on_change = true, + mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags()) + : mlir::PassManager::IRPrinterConfig( + print_module_scope, print_after_only_on_change, + /*printAfterOnlyOnFailure=*/false, op_printing_flags), + mlir_pass_count_(1), + dump_file_prefix_(dump_file_prefix) {} + + void printBeforeIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override { + Dump(pass, print_callback, /*is_before=*/true); + } + + void printAfterIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override { + Dump(pass, print_callback, /*is_before=*/false); + } + + private: + int64_t mlir_pass_count_; + absl::string_view dump_file_prefix_; + // Map from pass ptr to dump files and pass number. 
+ // + // Each pass has unique and stable pointer, even for passes with the same + // name. E.g. a PassManager could have multiple Canonicalizer passes. + // We use this property to uniquely determine a Pass in a PassManager. + // + // If multiple consecutive func passes are applied to a Module. PassManager + // will iterate over the func in the outer loop and apply the passes in the + // inner loop. This may cause passes to run out-of-order. But the 1st runs of + // each pass are still in-order. So we use pass_to_number_map_ to keep track + // of the number for each pass. + llvm::DenseMap> + pass_to_dump_file_before_map_; + llvm::DenseMap> + pass_to_dump_file_after_map_; + llvm::DenseMap pass_to_number_map_; + + // Get the unique number for each pass. + int64_t GetPassNumber(mlir::Pass* pass) { + if (!pass_to_number_map_.contains(pass)) { + pass_to_number_map_[pass] = mlir_pass_count_++; + } + return pass_to_number_map_[pass]; + } + + void Dump(mlir::Pass* pass, PrintCallbackFn print_callback, bool is_before) { + auto& pass_to_dump_file_map = is_before ? pass_to_dump_file_before_map_ + : pass_to_dump_file_after_map_; + if (!pass_to_dump_file_map.contains(pass)) { + std::string filename = llvm::formatv( + "{0}_{1,0+4}_{2}_{3}.mlir", dump_file_prefix_, GetPassNumber(pass), + pass->getName().str(), is_before ? 
"before" : "after"); + absl::StatusOr> dump_file = + CreateMlirDumpFile(filename); + if (!dump_file.ok()) { + LOG(WARNING) << "Failed to dump MLIR module to " << filename; + return; + } + pass_to_dump_file_map[pass] = std::move(*dump_file); + } + + return print_callback(*(pass_to_dump_file_map[pass])); + } +}; + } // namespace -void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm) { +void EnableIrPrinting(mlir::PassManager& pm, + absl::string_view file_name_prefix) { mlir::OpPrintingFlags flag{}; flag.useLocalScope().elideLargeElementsAttrs().enableDebugInfo(); @@ -112,39 +223,23 @@ void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm) { // `PassManager::enableIRPrinting`, except for the `printModuleScope` // parameter, which is true by default. It is set to false to avoid the dump // file size becoming too large when the passes are running on a large model. - pm.enableIRPrinting( - /*shouldPrintBeforePass=*/[](mlir::Pass *, - mlir::Operation *) { return true; }, - /*shouldPrintAfterPass=*/ - [](mlir::Pass *, mlir::Operation *) { return true; }, - /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, - /*printAfterOnlyOnFailure=*/false, out_stream, flag); - - LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled."; + pm.enableIRPrinting(std::make_unique( + file_name_prefix, /*print_module_scope=*/false, + /*print_after_only_on_change=*/true, flag)); } // TODO(b/259374854): Create tests for MaybeEnableIrPrinting. 
-absl::StatusOr> MaybeEnableIrPrinting( - mlir::PassManager &pm, const absl::string_view name) { +absl::Status MaybeEnableIrPrinting(mlir::PassManager& pm, + absl::string_view file_name_prefix) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "Verbosity level too low to enable IR printing."; - return nullptr; + return absl::OkStatus(); } - absl::StatusOr> dump_file = - CreateMlirDumpFile(/*dump_file_name=*/absl::StrCat(name, ".mlir")); - if (absl::IsFailedPrecondition(dump_file.status())) { - // Requirements for enabling IR dump are not met. IR printing will not be - // enabled. - LOG(WARNING) << dump_file.status(); - return nullptr; - } else if (!dump_file.ok()) { - return dump_file.status(); - } - - EnableIrPrinting(**dump_file, pm); + EnableIrPrinting(pm, file_name_prefix); - return dump_file; + LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled."; + return absl::OkStatus(); } } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h index 803cd39a0a5bae..38a9c4fae4f912 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h @@ -15,27 +15,29 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ -#include - -#include "absl/status/statusor.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Pass/PassManager.h" // from @llvm-project namespace tensorflow { namespace quantization { -// Enables IR printing for `pm`. When the passes are run, the IRs will be dumped -// to `out_stream`. -void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm); +// Enables IR printing for `pm`. 
When the passes are run, each pass will dump to +// its own file with prefix `file_name_prefix`. +void EnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); // If verbosity level >= 1, this will dump intermediate IRs of passes to a file. -// The file path is given by prefixing `name`.mlir with the value of the -// TF_QUANT_MLIR_DUMP_PREFIX env variable. Returns `nullptr` iff the verbosity -// level < 1 or TF_QUANT_MLIR_DUMP_PREFIX is not set or set to an empty string. -// The returned ostream instance should live until the pass run is complete. -absl::StatusOr> MaybeEnableIrPrinting( - mlir::PassManager &pm, absl::string_view name); +// The dumped mlir files with be under a directory determined by +// the TF_QUANT_MLIR_DUMP_PREFIX env variable. The PassManager will dump to a +// new file for each pass. The file name will have the format +// {file_name_prefix}_{pass_number}_{pass_name}_{before|after}.mlir. +// * `file_name_prefix` is from input. +// * `pass_number` increments from 1 for each pass. +// * `pass_name` is the name of the pass. +// * `before|after` indicates whether the dump occurs before or after the pass. +absl::Status MaybeEnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc index a7162d9a05a4ff..c3034f4294b13d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc @@ -16,23 +16,30 @@ limitations under the License. 
#include #include +#include #include "absl/cleanup/cleanup.h" -#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinDialect.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" namespace tensorflow { namespace quantization { -namespace { +namespace mlir_dump_test { class NoOpPass : public mlir::PassWrapper> { @@ -69,12 +76,7 @@ class ParentPass pm.addPass(CreateNoOpPass()); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; - - EnableIrPrinting(dump_file, pm); + EnableIrPrinting(pm, "dump2"); if (failed(pm.run(module_op))) { signalPassFailure(); @@ -86,41 +88,88 @@ std::unique_ptr> CreateParentPass() { return std::make_unique(); } -TEST(EnableIrPrintingTest, PassSuccessfullyRuns) { - mlir::MLIRContext ctx{}; +} // namespace mlir_dump_test - mlir::PassManager pm = {&ctx}; - pm.addPass(CreateNoOpPass()); +namespace { - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - 
llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; +using namespace tensorflow::quantization::mlir_dump_test; - EnableIrPrinting(dump_file, pm); +class EnableIrPrintingTest : public ::testing::Test { + protected: + EnableIrPrintingTest() : env_(tsl::Env::Default()) { + if (!tsl::io::GetTestUndeclaredOutputsDir(&test_dir_)) { + test_dir_ = tsl::testing::TmpDir(); + } + } - mlir::OpBuilder builder(&ctx); - auto module_op = builder.create(builder.getUnknownLoc()); - // Destroy by calling destroy() to avoid memory leak since it is allocated - // with malloc(). - const absl::Cleanup module_op_cleanup = [module_op] { module_op->destroy(); }; + void SetUp() override { + tsl::setenv("TF_QUANT_MLIR_DUMP_PREFIX", test_dir_.c_str(), 1); - const mlir::LogicalResult result = pm.run(module_op); + mlir::DialectRegistry dialects; + dialects.insert(); + ctx_ = std::make_unique(dialects); + ctx_->loadAllAvailableDialects(); + } + + void TearDown() override { + // Delete files in the test directory. + std::vector files; + TF_ASSERT_OK( + env_->GetMatchingPaths(tsl::io::JoinPath(test_dir_, "*"), &files)); + for (const std::string& file : files) { + TF_ASSERT_OK(env_->DeleteFile(file)); + } + } + + tsl::Env* env_; + std::string test_dir_; + std::unique_ptr ctx_; +}; + +TEST_F(EnableIrPrintingTest, PassSuccessfullyRuns) { + mlir::PassManager pm = {ctx_.get()}; + pm.addPass(CreateNoOpPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + + EnableIrPrinting(pm, "dump"); + + constexpr absl::string_view program = R"mlir( +module{ + func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + return %arg0 : tensor<10xf32> + } + func.func @func1(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + %1 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + return %0 : tensor<10xf32> + } +})mlir"; + auto module_op = mlir::parseSourceString(program, ctx_.get()); + + const 
mlir::LogicalResult result = pm.run(module_op.get()); EXPECT_FALSE(failed(result)); + + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump_0001_tensorflow::quantization::mlir_dump_test" + "::NoOpPass_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0002_Canonicalizer_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0002_Canonicalizer_after.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, "dump_0003_Canonicalizer_before.mlir"))); } -TEST(EnableNestedIrPrintingTest, PassSuccessfullyRuns) { +TEST_F(EnableIrPrintingTest, NestedPassSuccessfullyRuns) { mlir::MLIRContext ctx{}; mlir::PassManager pm = {&ctx}; pm.addPass(CreateParentPass()); - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string tmp_dump_filename = - tsl::io::GetTempFilename(/*extension=*/".mlir"); - llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; - - EnableIrPrinting(dump_file, pm); + EnableIrPrinting(pm, "dump"); mlir::OpBuilder builder(&ctx); auto module_op = builder.create(builder.getUnknownLoc()); @@ -130,6 +179,15 @@ TEST(EnableNestedIrPrintingTest, PassSuccessfullyRuns) { const mlir::LogicalResult result = pm.run(module_op); EXPECT_FALSE(failed(result)); + + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump_0001_tensorflow::quantization::mlir_dump_test" + "::ParentPass_before.mlir"))); + TF_EXPECT_OK(tsl::Env::Default()->FileExists( + tsl::io::JoinPath(test_dir_, + "dump2_0001_tensorflow::quantization::mlir_dump_test" + "::NoOpPass_before.mlir"))); } } // namespace } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD index fa201ff6a716bc..7042b6c5b17cdb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD @@ -20,8 +20,8 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 3a81ff0dfd2c91..fb13da8489a81c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -14,17 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" -#include -#include #include #include #include #include "absl/container/flat_hash_set.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 595534433849ee..a6ad3e34835672 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" @@ -54,7 +55,76 @@ namespace { using DebuggerType = tensorflow::quantization::DebuggerOptions::DebuggerType; using DebuggerOptions = tensorflow::quantization::DebuggerOptions; +constexpr StringRef kEntryFuncAttrName = "_entry_function"; +constexpr StringRef kOriginalEntryFuncAttrName = "_original_entry_function"; constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kEmptyNodeName = "_empty_node"; + +template +std::pair GetFuncNameAndNodeName( + LiftedOp op, const FlatSymbolRefAttr &f_attr) { + static_assert(false, + "GetFuncNameAndNodeName for call_op is not implemented."); +} + +// Returns a pair: `func_name` and `node_name` for the lifted function. In TF +// quantizer, both are filled. For StableHLO quantizer, the func_name is only +// filled and node_name is always set to "_empty_node". 
+template <> +std::pair +GetFuncNameAndNodeName(TF::PartitionedCallOp call_op, + const FlatSymbolRefAttr &f_attr) { + std::optional quant_unit = + FindQuantizationUnitFromLoc(call_op->getLoc()); + return std::make_pair(quant_unit->func_name(), quant_unit->node_name()); +} + +template <> +std::pair GetFuncNameAndNodeName( + TF::XlaCallModuleOp call_op, const FlatSymbolRefAttr &f_attr) { + return std::make_pair(f_attr.getValue().str(), kEmptyNodeName.str()); +} + +template +Operation *DuplicateOp(LiftedOp op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + static_assert(false, "DuplicateOp for call_op is not implemented."); +} + +template <> +Operation *DuplicateOp( + TF::PartitionedCallOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create PartitionedCallOp to the copied composite function. This + // PartitionedCallOp does not have kQuantTraitAttrName, and therefore won't + // get quantized. + auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + FlatSymbolRefAttr::get(new_ref_func_name)); + return new_call_op; +} + +template <> +Operation *DuplicateOp( + TF::XlaCallModuleOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create XlaCallModuleOp to the copied composite function. This + // XlaCallModuleOp does not have kQuantTraitAttrName, and therefore won't get + // quantized. 
+ auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + call_op.getVersionAttr(), call_op.getModuleAttr(), call_op.getSoutAttr()); + new_call_op->setAttr(kEntryFuncAttrName, + rewriter.getStringAttr(new_ref_func_name.getValue())); + new_call_op->setAttrs(call_op->getAttrs()); + new_call_op->removeAttr(rewriter.getStringAttr(kQuantTraitAttrName)); + + FlatSymbolRefAttr new_func_name_attr = + FlatSymbolRefAttr::get(rewriter.getContext(), new_ref_func_name); + new_call_op->setAttr(kEntryFuncAttrName, new_func_name_attr); + new_call_op->setAttr(kOriginalEntryFuncAttrName, new_ref_func_name); + return new_call_op; +} // AddDumpTensorOp pass adds DumpTensorOp - which saves entire value of its // input into a file - to quantizable layer's output. @@ -110,49 +180,66 @@ class AddDumpTensorOpPass std::string log_dir_path_ = "/tmp/dumps"; }; -class AddDumpTensorOp : public OpRewritePattern { +template +class AddDumpTensorOp : public OpRewritePattern { public: // Does not take ownership of context, which must refer to a valid value that // outlives this object. explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, std::string log_dir_path) - : OpRewritePattern(context), + : OpRewritePattern(context), debugger_type_(debugger_type), log_dir_path_(std::move(log_dir_path)) {} private: - DebuggerType debugger_type_; - std::string log_dir_path_; + SmallVector CreateDumpAttributes( + PatternRewriter &rewriter, const StringRef folder_name, + const StringRef file_name, const bool enabled, const StringRef func_name, + const StringRef node_name) const { + SmallVector dump_attributes{ + rewriter.getNamedAttr("log_dir_path", + rewriter.getStringAttr(folder_name)), + rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), + // The op is disabled by default. Otherwise, values will be saved + // during calibration. 
+ rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), + rewriter.getNamedAttr("func_name", rewriter.getStringAttr(func_name)), + rewriter.getNamedAttr("node_name", rewriter.getStringAttr(node_name)), + }; + return dump_attributes; + } - LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, - PatternRewriter &rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); - if (!call_op->hasAttr(kQuantTraitAttrName)) { - return failure(); - } - if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) { - return failure(); - } + StringAttr DuplicateFunction(Operation *op, + const FlatSymbolRefAttr &f_attr) const { + ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); - // For now, only support ops with 1 results - if (call_op->getNumResults() != 1) return failure(); + const func::FuncOp ref_func = + dyn_cast_or_null(symbol_table.lookup(f_attr.getValue())); + func::FuncOp new_ref_func = dyn_cast(ref_func->clone()); + return symbol_table.insert(new_ref_func); + } - Value result = call_op->getResult(0); + LogicalResult match(LiftedOpT op) const override { + if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { + return failure(); + } - // If one of the user is DumpTensorOp, do nothing + Value result = op->getResult(0); for (auto user : result.getUsers()) { if (dyn_cast_or_null(user)) return failure(); } - rewriter.setInsertionPointAfterValue(result); - - std::optional quant_unit = - FindQuantizationUnitFromLoc(call_op->getLoc()); + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); + if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) return failure(); + return success(); + } - if (!quant_unit.has_value()) return failure(); + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const override { + // Only support ops with 1 results + Value result = op->getResult(0); + rewriter.setInsertionPointAfterValue(result); - auto folder_name = - tensorflow::io::JoinPath(log_dir_path_, 
f_attr.getValue()); // In Whole model, we first need to set file_name as // unquantized_tensor_data.pb as it is used by unquantized dump model. // After saving unquantized dump model, the file name will be changed to @@ -161,77 +248,56 @@ class AddDumpTensorOp : public OpRewritePattern { // as quantized_tensor_data.pb here. // TODO: b/296933893 - Refactor the debugger code when no quantize option // is added - auto file_name = + std::string file_name = debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_WHOLE_MODEL ? "unquantized_tensor_data.pb" : "quantized_tensor_data.pb"; - SmallVector dump_attributes{ - rewriter.getNamedAttr("log_dir_path", - rewriter.getStringAttr(folder_name)), - rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), - // The op is disabled by default. Otherwise, values will be saved - // during calibration. - rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), - rewriter.getNamedAttr("func_name", - rewriter.getStringAttr(quant_unit->func_name())), - rewriter.getNamedAttr("node_name", - rewriter.getStringAttr(quant_unit->node_name())), - }; + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); + + // In TF::PartitionedCallOp case, func_name and node_name are filled. + // But in TF::XlaCallModuleOp case, node_name is `kEmptyNodeName` since + // debugging and selective quantization of StableHLO Quantizer only uses + // func_name for op matching. + auto [func_name, node_name] = GetFuncNameAndNodeName(op, f_attr); + std::string folder_name = + tensorflow::io::JoinPath(log_dir_path_, f_attr.getValue()); - rewriter.create(call_op->getLoc(), TypeRange{}, result, + // Attach DumpTensorOp to its output layer. + SmallVector dump_attributes = + CreateDumpAttributes(rewriter, folder_name, file_name, + /*enabled=*/false, func_name, node_name); + rewriter.create(op->getLoc(), TypeRange{}, result, dump_attributes); // Per-layer mode. 
if (debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_INT_PER_LAYER || debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_FLOAT_PER_LAYER) { - auto module = call_op->getParentOfType(); - SymbolTable symbol_table(module); - - // Copy composite function of quantizable layer. - const mlir::func::FuncOp ref_func = dyn_cast_or_null( - symbol_table.lookup(f_attr.getValue())); - mlir::func::FuncOp new_ref_func = - dyn_cast(ref_func->clone()); - const StringAttr new_ref_func_name = symbol_table.insert(new_ref_func); - - // Create PartitionedCallOp to the copied composite function. - // This PartitionedCallOp does not have kQuantTraitAttrName, and therefore - // won't get quantized. - auto ref_call_op = rewriter.create( - call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), - FlatSymbolRefAttr::get(new_ref_func_name)); - - // Attach DumpTensorOp to its output unquantized layer. - SmallVector dump_attributes{ - rewriter.getNamedAttr("log_dir_path", - rewriter.getStringAttr(folder_name)), - rewriter.getNamedAttr("file_name", rewriter.getStringAttr( - "unquantized_tensor_data.pb")), - rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(false)), - rewriter.getNamedAttr( - "func_name", rewriter.getStringAttr(quant_unit->func_name())), - rewriter.getNamedAttr( - "node_name", rewriter.getStringAttr(quant_unit->node_name())), - }; - - rewriter.create(call_op->getLoc(), TypeRange{}, - ref_call_op.getResult(0), - dump_attributes); + // Duplicate composite function and op of quantizable layer for creating + // unquantized layer. + StringAttr new_ref_func_name = DuplicateFunction(op, f_attr); + Operation *new_op = DuplicateOp(op, rewriter, new_ref_func_name); + + // Attach second DumpTensorOp to its output unquantized layer. 
+ SmallVector dump_attributes = CreateDumpAttributes( + rewriter, folder_name, /*file_name=*/"unquantized_tensor_data.pb", + /*enabled=*/false, func_name, node_name); + rewriter.create(op.getLoc(), TypeRange{}, + new_op->getResult(0), dump_attributes); if (debugger_type_ == DebuggerOptions::DEBUGGER_TYPE_FLOAT_PER_LAYER) { // Swap all uses between call_op and ref_call_op, except for the // particular use that owns DumpTensor. rewriter.replaceUsesWithIf( - call_op.getResult(0), ref_call_op.getResult(0), - [](OpOperand &use) -> bool { + op.getResult(0), new_op->getResult(0), [](OpOperand &use) -> bool { return !isa(use.getOwner()); }); } } - - return success(); } + + DebuggerType debugger_type_; + std::string log_dir_path_; }; static PassRegistration pass; @@ -241,7 +307,10 @@ void AddDumpTensorOpPass::runOnOperation() { RewritePatternSet patterns(ctx); ModuleOp module = getOperation(); - patterns.add(ctx, debugger_type_, log_dir_path_); + patterns.add, + AddDumpTensorOp>(ctx, debugger_type_, + log_dir_path_); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { module.emitError() << "quant-add-dump-tensor-op failed."; signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index 6e94beb6b0a057..f1f65a1a183371 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -107,20 +108,6 @@ class MergeInitializerFunctionOpsToMainPass } }; -// Gets the "main" function from the module. Returns an empty op iff it doesn't -// exist. -func::FuncOp GetMainFunction(ModuleOp module_op) { - const auto main_func_id = - StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); - auto func_ops = module_op.getOps(); - auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { - return func_op.getName() == main_func_id; - }); - - if (main_func_itr == func_ops.end()) return {}; - return *main_func_itr; -} - // Returns true iff func_op has either no Region or the body has no Blocks. 
bool IsFuncOpEmpty(func::FuncOp func_op) { return func_op->getNumRegions() == 0 || func_op.getBody().empty(); @@ -336,7 +323,7 @@ void MergeInitializerFunctionOpsToMainPass::runOnOperation() { ModuleOp module_op = getOperation(); MLIRContext* ctx = module_op.getContext(); - func::FuncOp main_func_op = GetMainFunction(module_op); + func::FuncOp main_func_op = FindMainFuncOp(module_op); if (!main_func_op) { module_op.emitError("Main function op not found."); return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 886a27011b1825..ebdd374288a065 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project @@ -120,6 +121,27 @@ bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, return val1_result == val2_result; } +// Checks if a shape has dim sizes of all ones except the right most dim. 
+bool ReshapableTo1DTensor(ShapedType rhs_shape) { + for (auto rank = 0; rank < rhs_shape.getRank() - 1; rank++) { + if (rhs_shape.getDimSize(rank) != 1) { + return false; + } + } + return true; +} + +Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { + auto shape = value.getType().cast(); + if (shape.getRank() != 1) { + SmallVector new_shape; + new_shape.push_back(shape.getNumElements()); + value = builder.create( + loc, value, Create1DConstValue(builder, loc, new_shape)); + } + return ConstantFoldOpIfPossible(value.getDefiningOp()).front(); +} + // Matches convolution op with "NHWC" data format or matmul op with false adj_y. // The list of supported ops in this function is: // - Conv2DOp diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index 30e298dd6e7048..d75a01be7d2182 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -82,6 +82,13 @@ class HasEqualElementSize shape_1, list shape_2> : Constraint< "llvm::ArrayRef({" # !interleave(shape_2, ", ") # "}))">, "Checks if the given dimensions contain the same number of elements.">; +def ReshapableTo1DTensor : Constraint< + CPred<"quant::ReshapableTo1DTensor($0.getType().cast())">, + "Checks if the value dims are all ones except the right most dim">; + +def ReshapeTo1DTensor : NativeCodeCall< + "quant::ReshapeTo1DTensor($_builder, $_loc, $0)">; + def HasEqualShape : Constraint().hasRank() && " "$1.getType().cast().hasRank() && " @@ -112,7 +119,29 @@ def ConvertAddToBiasAdd : Pat< (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), (TF_BiasAddOp $conv_out, $add_rhs, (CreateStringAttr<"NHWC">)), [(HasRankOf<1> $add_rhs_value), - (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)]>; + (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)], [], (addBenefit 
-1)>; + +// Convert conv+sub+mul pattern to conv+mul+add. +// (conv - sub) * mul -> conv * mul + (-sub) * mul +// +// This is needed to support Conv+BatchNorm pattern from Jax models converted +// using jax2tf w/o native serialization. Note that Jax2tf patterns always +// extend bias shapes to a rank of 4, e.g. 1x1x1x5. +def ConvertSubMulToMulAdd : Pat< + (TF_MulOp + (TF_SubOp + (SupportedAffineOpMatcher $conv_out, $input, $weight), + (TF_ConstOp:$sub_rhs IsFloatElementsAttr:$sub_rhs_value)), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (TF_AddV2Op + (TF_MulOp $conv_out, (ReshapeTo1DTensor $mul_rhs)), + (TF_MulOp + (TF_NegOp (ReshapeTo1DTensor $sub_rhs)), + (ReshapeTo1DTensor $mul_rhs))), + [(ReshapableTo1DTensor $mul_rhs), + (ReshapableTo1DTensor $sub_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $mul_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $sub_rhs)]>; // TODO(b/278493977): Create generic implementation of lifting any fused op // with any reshaping op @@ -128,6 +157,7 @@ def ConvertAddWithReshapeToBiasAddWithReshape : Pat< (HasEqualElementSize<[-1], [0]> $reshape_out, $add_rhs)]>; // Fuse consecutive BiasAddOp and an AddV2Op. +// We also handle the case where add_rhs has rank 4. def FuseBiasAndAddV2 : Pat< (TF_AddV2Op (TF_BiasAddOp:$bias_add @@ -135,9 +165,10 @@ def FuseBiasAndAddV2 : Pat< (TF_ConstOp:$bias IsFloatElementsAttr:$bias_value), $data_format), (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), (TF_BiasAddOp - $conv_out, (TF_AddV2Op $bias, $add_rhs), $data_format), + $conv_out, (TF_AddV2Op $bias, (ReshapeTo1DTensor $add_rhs)), $data_format), [(HasOneUse $bias_add), - (HasEqualShape $bias_value, $add_rhs_value)]>; + (ReshapableTo1DTensor $add_rhs), + (HasEqualElementSize<[-1], [-1]> $bias, $add_rhs)]>; // Fuse AffineOp followed by an MulOp patterns. 
def FuseAffineOpAndMul : Pat< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index b5fb96396f7ef9..20ffa2adcfa969 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -414,7 +414,7 @@ void PrepareQuantizePass::runOnOperation() { ApplyQuantizationParamsPropagation( func, is_signed, /*bit_width=*/8, !enable_per_channel_quantization_, GetTFOpQuantSpec, GetTfQuantScaleSpec, infer_tensor_range, - quant_specs_.legacy_float_scale); + quant_specs_.legacy_float_scale, /*is_qdq_conversion=*/false); RewritePatternSet patterns2(ctx); patterns2.add(ctx); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 5248d95c9f9e10..490cc1fc889b91 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -11,6 +11,7 @@ package( default_visibility = [ "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", + "//tensorflow/lite:__subpackages__", "//tensorflow/python:__subpackages__", "//tensorflow/tools/pip_package/v2:__subpackages__", ], @@ -345,6 +346,7 @@ tf_py_strict_test( "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/framework:test_lib", "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/lib/io:tf_record", "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 
a67a6451341e4e..ca7027dbe312bd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io +from tensorflow.python.lib.io import tf_record from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -1790,6 +1791,15 @@ def gen_data() -> repr_dataset.RepresentativeDataset: 'enable_per_channel_quantization': True, 'dilations': [1, 2, 2, 1], }, + { + 'testcase_name': 'with_bias_and_relu6_to_stablehlo_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.STABLEHLO, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model( @@ -1950,6 +1960,10 @@ def data_gen() -> repr_dataset.RepresentativeDataset: ), ) self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + elif target_opset == quant_opts_pb2.STABLEHLO: + # This is to verify the invocation of StableHLO quantizer works. More + # thorough functional tests are in StableHLO quantizer directory. 
+ self.assertTrue(self._contains_op(output_graphdef, 'XlaCallModule')) else: self.assertTrue(self._contains_quantized_function_call(output_graphdef)) self.assertFalse(self._contains_op(output_graphdef, 'FusedBatchNormV3')) @@ -5831,7 +5845,7 @@ def test_while_op_model( class DebuggerTest(quantize_model_test_base.QuantizedModelTest): - def _run_model_in_sess(self, model_dir, tags, signature_key, sample_input): + def _run_model_in_sess(self, model_dir, tags, signature_key, sample_inputs): with tensorflow.compat.v1.Session(graph=tensorflow.Graph()) as sess: meta_graph = saved_model_loader.load(sess, tags, export_dir=model_dir) signature_def = meta_graph.signature_def[signature_key] @@ -5843,13 +5857,26 @@ def _run_model_in_sess(self, model_dir, tags, signature_key, sample_input): for output_tensor_info in signature_def.outputs.values() ] - feed_dict = {} - for input_key, input_value in sample_input.items(): - input_tensor_name = signature_def.inputs[input_key].name - feed_dict[input_tensor_name] = input_value + output_values = [] + for sample_input in sample_inputs: + feed_dict = {} + for input_key, input_value in sample_input.items(): + input_tensor_name = signature_def.inputs[input_key].name + feed_dict[input_tensor_name] = input_value - # Obtain the output of the model. - return sess.run(output_tensor_names, feed_dict=feed_dict)[0] + # Obtain the output of the model. 
+ output_values.append( + sess.run(output_tensor_names, feed_dict=feed_dict)[0] + ) + return output_values + + def _read_tensor_array_file(self, file_path): + tensor_protos = [] + for raw_record in tf_record.tf_record_iterator(file_path, options='ZLIB'): + tensor_protos.append( + tensorflow.make_ndarray(tensor_pb2.TensorProto.FromString(raw_record)) + ) + return np.array(tensor_protos) @parameterized.named_parameters( { @@ -5926,9 +5953,10 @@ def data_gen() -> repr_dataset.RepresentativeDataset: converted_model.signatures._signatures.keys(), {'serving_default'} ) - sample_input = { - 'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3)) - } + sample_inputs = [ + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + ] # Check if output of the model and value saved by DumpTensorOp matches. # Verify for both unquantized model and quantized model. @@ -5936,24 +5964,19 @@ def data_gen() -> repr_dataset.RepresentativeDataset: [unquantized_dump_model_path, 'unquantized_tensor_data.pb'], [self._output_saved_model_path, 'quantized_tensor_data.pb'], ]: - output_value = self._run_model_in_sess( - model_path, tags, 'serving_default', sample_input + output_values = self._run_model_in_sess( + model_path, tags, 'serving_default', sample_inputs ) # Find the dump file and parse it. folder = os.path.join(log_dir_path, os.listdir(log_dir_path)[0]) dump_file_path = os.path.join(log_dir_path, folder, file_name) - - dump_file_proto = tensor_pb2.TensorProto.FromString( - open(dump_file_path, 'rb').read() - ) - - dump_file_numpy = tensorflow.make_ndarray(dump_file_proto) + dump_file_numpy = self._read_tensor_array_file(dump_file_path) # Since the model only has one conv2d and its output is directly used as # the output of the model, output of the model and conv2d's dump value # should be the same. 
- self.assertAllEqual(output_value, dump_file_numpy) + self.assertAllEqual(output_values, dump_file_numpy) # Verify if quant_unit.pb file was created correctly. quant_unit_file_path = os.path.join(log_dir_path, folder, 'quant_unit.pb') @@ -6070,15 +6093,16 @@ def data_gen() -> repr_dataset.RepresentativeDataset: converted_model.signatures._signatures.keys(), {'serving_default'} ) - sample_input = { - 'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3)) - } + sample_inputs = [ + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + {'input_tensor': np.random.uniform(low=0, high=1, size=(16, 3, 4, 3))}, + ] output_value_from_original_model = self._run_model_in_sess( - self._input_saved_model_path, tags, 'serving_default', sample_input + self._input_saved_model_path, tags, 'serving_default', sample_inputs ) output_value_from_debugging_model = self._run_model_in_sess( - self._output_saved_model_path, tags, 'serving_default', sample_input + self._output_saved_model_path, tags, 'serving_default', sample_inputs ) # Find the both quantized and unquantized dump file. 
@@ -6090,18 +6114,11 @@ def data_gen() -> repr_dataset.RepresentativeDataset: log_dir_path, folder, 'quantized_tensor_data.pb' ) - unquantized_dump_file_proto = tensor_pb2.TensorProto.FromString( - open(unquantized_dump_file_path, 'rb').read() - ) - quantized_dump_file_proto = tensor_pb2.TensorProto.FromString( - open(quantized_dump_file_path, 'rb').read() - ) - - unquantized_dump_file_numpy = tensorflow.make_ndarray( - unquantized_dump_file_proto + unquantized_dump_file_numpy = self._read_tensor_array_file( + unquantized_dump_file_path ) - quantized_dump_file_numpy = tensorflow.make_ndarray( - quantized_dump_file_proto + quantized_dump_file_numpy = self._read_tensor_array_file( + quantized_dump_file_path ) # Since the model only has one conv2d and its output is directly used as @@ -6143,169 +6160,46 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest): (default in TF2) to ensure support for when TF2 is disabled. """ - @parameterized.named_parameters( - { - 'testcase_name': 'with_min_max', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_min_max_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_min_max_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 
'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_average_min_max_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX - ), - }, - { - 'testcase_name': 'with_histogram_percentile', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_percentile_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_percentile_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - 
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_bruteforce_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, - ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + @parameterized.parameters( + parameter_combinations([{ + 'target_opset': [ + quant_opts_pb2.TF, + quant_opts_pb2.XLA, + quant_opts_pb2.UNIFORM_QUANTIZED, + ], + 'calibration_options': [ + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_max_frequency_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, - 
calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric', - 'target_opset': quant_opts_pb2.TF, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric_to_xla', - 'target_opset': quant_opts_pb2.XLA, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, - { - 'testcase_name': 'with_histogram_mse_symmetric_to_uq', - 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, - 'calibration_options': quant_opts_pb2.CalibrationOptions( - calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, - calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( - initial_num_bins=10, + quant_opts_pb2.CalibrationOptions( + 
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, + calibration_parameters=quant_opts_pb2.CalibrationOptions.CalibrationParameters( + initial_num_bins=10, + ), ), - ), - }, + ], + }]) ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model_by_calibration_options( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index d9f8c9781fc4ca..902ee3e5c94e2e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -22,7 +22,6 @@ from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model @@ -500,31 +499,6 @@ def _run_graph_for_calibration( logging.info('Calibration step complete.') -def _get_min_max_from_calibrator( - node_id: bytes, - calib_opts: quantization_options_pb2.CalibrationOptions, -) -> tuple[float, float]: - """Calculate min and max from statistics using calibration options. - - Args: - node_id: bytes of node id. - calib_opts: Calibration options used for calculating min and max. - - Returns: - (min_value, max_value): Min and max calculated using calib_opts. - - Raises: - ValueError: Unsupported calibration method is given. 
- """ - statistics: calibration_statistics_pb2.CalibrationStatistics = ( - pywrap_calibration.get_statistics_from_calibrator(node_id) - ) - min_value, max_value = calibration_algorithm.get_min_max_value( - statistics, calib_opts - ) - return min_value, max_value - - class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): """Wrapper class for overridden python method definitions. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 84959d9ba30405..512102d0a5f53c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -93,7 +93,8 @@ absl::StatusOr> RunExportPasses( } if (absl::Status pass_run_status = RunPasses( - /*name=*/export_opts.debug_name, + /*name=*/ + export_opts.debug_name, /*add_passes_func=*/ [dup_constants = export_opts.duplicate_shape_determining_constants]( mlir::PassManager &pm) { AddExportPasses(pm, dup_constants); }, @@ -163,7 +164,8 @@ absl::StatusOr QuantizeQatModel( /*deserialize_xla_call_module=*/false)); TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantQatStepName, /*add_passes_func=*/ + /*name=*/ + kTfQuantQatStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizeQatPasses(pm, quantization_options, kTfQuantQatStepName); }, @@ -243,7 +245,8 @@ absl::StatusOr QuantizePtqModelPreCalibration( *module_ref, QuantizationConfig())); } else { TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPreCalibrationStepName, /*add_passes_func=*/ + /*name=*/ + kTfQuantPtqPreCalibrationStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqPreCalibrationPasses(pm, quantization_options); }, @@ -325,12 +328,22 @@ absl::StatusOr QuantizePtqModelPostCalibration( // Use StableHLO Quantizer option if opset is specified. 
if (is_stablehlo) { + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset() + ->set_enable_per_channel_quantized_weight( + quantization_options.enable_per_channel_quantization()); + // When targeting server TPUs quantized types should be unpacked into + // integer ops. + quantization_config.mutable_pipeline_config()->set_unpack_quantized_types( + true); + PostCalibrationComponent post_calibration_component(context.get()); TF_ASSIGN_OR_RETURN(*module_ref, post_calibration_component.Run( - *module_ref, QuantizationConfig())); + *module_ref, quantization_config)); } else { TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqPostCalibrationStepName, /*add_passes_func=*/ + /*name=*/ + kTfQuantPtqPostCalibrationStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqPostCalibrationPasses( pm, quantization_options, kTfQuantPtqPostCalibrationStepName); @@ -405,7 +418,8 @@ absl::StatusOr QuantizePtqDynamicRange( /*run_tf_to_stablehlo=*/false, /*deserialize_xla_call_module=*/false)); TF_RETURN_IF_ERROR(RunPasses( - /*name=*/kTfQuantPtqDynamicRangeStepName, /*add_passes_func=*/ + /*name=*/ + kTfQuantPtqDynamicRangeStepName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { AddQuantizePtqDynamicRangePasses(pm, quantization_options, kTfQuantPtqDynamicRangeStepName); @@ -484,13 +498,14 @@ absl::StatusOr QuantizeWeightOnly( /*run_tf_to_stablehlo=*/false, /*deserialize_xla_call_module=*/false)); - TF_RETURN_IF_ERROR( - RunPasses(/*name=*/kTfQuantWeightOnlyStepName, /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { - AddQuantizeWeightOnlyPasses(pm, quantization_options, - kTfQuantWeightOnlyStepName); - }, - *context, *module_ref)); + TF_RETURN_IF_ERROR(RunPasses( + kTfQuantWeightOnlyStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizeWeightOnlyPasses(pm, quantization_options, + kTfQuantWeightOnlyStepName); + }, + 
*context, *module_ref)); const bool unfreeze_constants = !quantization_options.freeze_all_variables(); TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 13db5fff7a8cdc..556222ef4797d5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -682,15 +682,16 @@ def _populate_quantization_options_default_values( == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 ) or ( - quantization_options.op_set == quant_opts_pb2.OpSet.XLA + quantization_options.op_set + in (quant_opts_pb2.OpSet.XLA, quant_opts_pb2.OpSet.STABLEHLO) and quantization_options.quantization_method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8 ) ): raise ValueError( 'Currently, per-channel quantization is supported for Uniform Quantized' - ' opset, weight only quantization, or XLA opset with static range' - ' quantization.' + ' opset, weight only quantization, or XLA/StableHLO opset with static' + ' range quantization.' ) if ( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 136aab9f583030..2ca81b72aa71aa 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" @@ -56,8 +57,10 @@ void AddUnfuseMhloOpsPasses(mlir::PassManager& pm) { mlir::mhlo::createLegalizeEinsumToDotGeneralPass()); pm.addNestedPass( mlir::mhlo::createLegalizeDotToDotGeneralPass()); - pm.addNestedPass( - mlir::quant::stablehlo::createUnfuseMhloBatchNormPass()); + // Unfuse mhlo BatchNorm to primitive ops. + pm.addNestedPass(mlir::odml::createUnfuseBatchNormPass()); + // Fuse Conv + Mul to Conv. + pm.addNestedPass(mlir::odml::createFuseConvolutionPass()); pm.addNestedPass( mlir::mhlo::createLegalizeTorchIndexSelectToGatherPass()); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h index c7e191796031f4..740dca6c7b106b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/core/public/session.h" namespace tensorflow { @@ -58,6 +59,8 @@ inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, /*deserialize_xla_call_module=*/false); } +void AddTFToStablehloPasses(mlir::PassManager& pm); + } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir new file mode 100644 index 00000000000000..f4ef2e0f1d26f8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/add_dump_tensor_op_stablehlo.mlir @@ -0,0 +1,76 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-add-dump-tensor-op='debugger_type=int_per_layer' | FileCheck --check-prefix=IntPerLayer %s + +module { + func.func @matmul2(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor) { + %0 = stablehlo.constant dense<[-0.211145893, -0.708605706]> : tensor<2xf32> + %1 = stablehlo.constant dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor + %3 = "tf.XlaCallModule"(%2, %1, %0) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, 
tensor<2x2xf32>, tensor<2xf32>) -> tensor + return %3 : tensor + } + func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<6.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<2x2xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<2xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + %6 = stablehlo.clamp %0, %5, %1 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } + func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> tensor attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.constant dense<6.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<2x2xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<2xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + %6 = stablehlo.clamp %0, %5, %1 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } + +// IntPerLayer-LABEL: func @matmul2 +// IntPerLayer-DAG: %[[b0:.*]] = stablehlo.constant dense<[-0.211145893 +// IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2, 
_original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_2", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_2", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul1_q:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = [#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: %[[matmul1_uq:.*]] = "tf.XlaCallModule"(%[[matmul0_q]], %[[w0]], %[[b0]]) <{Sout = 
[#tf_type.shape], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0, _original_entry_function = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0"} : (tensor, tensor<2x2xf32>, tensor<2xf32>) -> tensor +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul1_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_with_bias_and_relu6_dynamic_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_with_bias_and_relu6_dynamic_fn_1", node_name = "_empty_node"}> : (tensor) -> () +// IntPerLayer-DAG: return %[[matmul1_q]] : tensor +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_2_0 +// IntPerLayer-DAG: func.func private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1_0 +} + +// ----- + +module { + func.func @matmul_concat(%arg0: tensor<1x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x3xf32>) { + %0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706]]> : tensor<2x3xf32> + %1 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> + return %3 : tensor<2x3xf32> + } + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes 
{_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + +// IntPerLayer-LABEL: func @matmul_concat +// IntPerLayer-DAG: %[[w0:.*]] = stablehlo.constant dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[c0:.*]] = stablehlo.constant dense<1.000000e+00 +// IntPerLayer-DAG: %[[matmul0_q:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_q]]) <{enabled = false, file_name = "quantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: %[[matmul0_uq:.*]] = "tf.XlaCallModule"(%arg0, %[[w0]]) <{Sout = [#tf_type.shape<1x3>], module = "", version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1_0, _original_entry_function = "composite_dot_general_fn_1_0", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[matmul0_uq]]) <{enabled = false, file_name = "unquantized_tensor_data.pb", func_name = "composite_dot_general_fn_1", log_dir_path = "/tmp/dumps/composite_dot_general_fn_1", node_name = "_empty_node"}> : (tensor<1x3xf32>) -> () +// IntPerLayer-DAG: %[[concat:.*]] = stablehlo.concatenate %[[matmul0_q]], %[[c0]], dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32> +// IntPerLayer-DAG: return %[[concat]] : tensor<2x3xf32> +// IntPerLayer-DAG: func.func private 
@composite_dot_general_fn_1 +// IntPerLayer-DAG: func.func private @composite_dot_general_fn_1_0 +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir index 1e771e2586a61e..772b38f56e242b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir @@ -359,3 +359,43 @@ func.func @depthwise_conv2d_with_large_weight_and_add(%arg0: tensor<*xf32>) -> ( // CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) // CHECK-NEXT: return %[[BIASADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-0.0800000056> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// 
CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_mul_addv2(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_2 = "tf.Const"() {value = dense<0.300000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %3 = "tf.AddV2"(%2, %cst_2) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %3 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_mul_addv2 +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.200000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] diff --git a/tensorflow/compiler/mlir/register_common_dialects.cc b/tensorflow/compiler/mlir/register_common_dialects.cc index 4cda39bdbb6745..b089bd9a1eb787 100644 --- a/tensorflow/compiler/mlir/register_common_dialects.cc +++ 
b/tensorflow/compiler/mlir/register_common_dialects.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/cpu/hlo_xla_runtime_pipeline.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index e23afc78e6de3b..3bc7791bf2f477 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -71,14 +71,10 @@ 'flatbuffer_to_string', 'flatbuffer_translate', 'hlo_to_kernel', - 'hlo_to_llvm_ir', - 'ifrt-opt', 'json_to_flatbuffer', 'kernel-gen-opt', 'lhlo-tfrt-opt', - 'mlir-bisect', 'mlir-hlo-opt', - 'mlir-interpreter-runner', 'mlir-opt', 'mlir-tflite-runner', 'mlir-translate', @@ -99,13 +95,6 @@ 'tfg-transforms-opt', 'tfg-translate', 'tfjs-opt', - 'xla-cpu-opt', - 'xla-gpu-opt', - 'xla-mlir-gpu-opt', - 'xla-runtime-opt', - 'xla-translate', - 'xla-translate-gpu-opt', - 'xla-translate-opt', ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 5541e2fb580b8a..dc75547758e11f 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -46,20 +46,7 @@ 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tfrt', 'tensorflow/compiler/mlir/tools/kernel_gen', - 'tensorflow/compiler/mlir/xla', - os.path.join(external_srcdir, 'local_xla/xla/mlir/backends/cpu'), - os.path.join(external_srcdir, 'local_xla/xla/mlir/backends/gpu'), - os.path.join(external_srcdir, 'local_xla/xla/mlir/runtime'), - os.path.join(external_srcdir, 'local_xla/xla/mlir/tools/mlir_bisect'), os.path.join(external_srcdir, 'local_xla/xla/mlir_hlo'), - os.path.join(external_srcdir, 
'local_xla/xla/python/ifrt/ir/tests'), - os.path.join(external_srcdir, 'local_xla/xla/service/gpu/tests'), - os.path.join(external_srcdir, 'local_xla/xla/service/mlir_gpu'), - os.path.join(external_srcdir, 'local_xla/xla/translate'), - os.path.join( - external_srcdir, - 'local_xla/xla/translate/mhlo_to_lhlo_with_xla' - ), 'tensorflow/core/ir/importexport/', 'tensorflow/core/ir/tests/', 'tensorflow/core/transforms/', diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index 5ce3ccf0cc2c8d..5d9c1d32a92a10 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -40,6 +41,10 @@ limitations under the License. 
namespace mlir { namespace TF { +namespace { +constexpr char kCompositeDevice[] = "tf._composite_device"; +} // namespace + ResourceConstructingOps::ResourceConstructingOps(Operation *op) { if (op) ops.insert(op); } @@ -57,7 +62,11 @@ ResourceConstructingOps ResourceConstructingOps::getPessimisticValueState( auto global_tensor = tf_saved_model::LookupBoundInputOfType< tf_saved_model::GlobalTensorOp>(func, barg.getArgNumber(), symbol_table); - return ResourceConstructingOps(global_tensor); + ResourceConstructingOps result(global_tensor); + if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { + result.is_on_composite_device = true; + } + return result; } } else if (auto vh = dyn_cast(value.getDefiningOp())) { return ResourceConstructingOps(vh); @@ -74,17 +83,24 @@ ResourceConstructingOps ResourceConstructingOps::join( ResourceConstructingOps ret; ret.ops.insert(lhs.ops.begin(), lhs.ops.end()); ret.ops.insert(rhs.ops.begin(), rhs.ops.end()); + ret.is_on_composite_device = + lhs.is_on_composite_device || rhs.is_on_composite_device; return ret; } void ResourceConstructingOps::print(raw_ostream &os) const { - llvm::interleaveComma(ops, os << "["), os << "]"; + llvm::interleaveComma(ops, os << "["); + if (is_on_composite_device) { + os << " COMPOSITE"; + } + os << "]"; } void ResourceDataflowAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) { LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n"); + if (auto cast = dyn_cast(op)) { join(results[0], *operands[0]); } else if (auto while_op = dyn_cast(op)) { @@ -94,6 +110,30 @@ void ResourceDataflowAnalysis::visitOperation(Operation *op, join(getLatticeElement(arg), *getLatticeElement(value)); } } + } else if (auto while_op = dyn_cast(op)) { + func::FuncOp cond = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getCondAttr()); + func::FuncOp body = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getBodyAttr()); + for (auto &arg : 
while_op->getOpOperands()) { + BlockArgument cond_arg = cond.getArgument(arg.getOperandNumber()); + join(getLatticeElement(cond_arg), *getLatticeElement(arg.get())); + BlockArgument body_arg = body.getArgument(arg.getOperandNumber()); + join(getLatticeElement(body_arg), *getLatticeElement(arg.get())); + } + } else if (auto graph = dyn_cast(op)) { + for (auto &arg : graph.GetFetch()->getOpOperands()) { + if (arg.getOperandNumber() < graph.getNumResults()) { + auto result = graph.getResult(arg.getOperandNumber()); + join(getLatticeElement(result), *getLatticeElement(arg.get())); + } + } + } else if (auto island = dyn_cast(op)) { + for (auto &arg : island.GetYield()->getOpOperands()) { + auto result = island.getResult(arg.getOperandNumber()); + join(getLatticeElement(result), *getLatticeElement(arg.get())); + // getLatticeElement(arg.get())->print(llvm::errs()); + } } else { setAllToEntryStates(results); } diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 68f6fa2d44c763..61fdb0c39f0693 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -42,7 +42,8 @@ struct ResourceConstructingOps { static ResourceConstructingOps getPessimisticValueState(MLIRContext *context); static ResourceConstructingOps getPessimisticValueState(Value value); bool operator==(const ResourceConstructingOps &rhs) const { - return ops == rhs.ops; + return ops == rhs.ops && + is_on_composite_device == rhs.is_on_composite_device; } static ResourceConstructingOps join(const ResourceConstructingOps &lhs, @@ -52,6 +53,8 @@ struct ResourceConstructingOps { // The operation(s) which created the resource value. // IR constructs (i.e., GlobalTensorOp) are not const-correct. 
mutable DenseSet ops; + + bool is_on_composite_device = false; }; class ResourceDataflowAnalysis diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index b0d730898316d5..c95dd020497385 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -194,7 +194,7 @@ void CategorizeParallelIdsMap( groups_different_branch = 0; groups_from_only = 0; groups_to_only = 0; - for (auto [group, branch] : from) { + for (const auto& [group, branch] : from) { auto to_iter = to.find(group); if (to_iter == to.end()) { ++groups_from_only; @@ -207,7 +207,7 @@ void CategorizeParallelIdsMap( } } } - for (auto [group, _] : to) { + for (const auto& [group, _] : to) { auto from_iter = from.find(group); if (from_iter == from.end()) { ++groups_to_only; @@ -246,13 +246,13 @@ void SideEffectAnalysisInfo::SetLastWrites( void SideEffectAnalysisInfo::Enter() { per_resource_access_info_.clear(); - for (auto [resource, last_writes] : stack_down_.back()) { + for (const auto& [resource, last_writes] : stack_down_.back()) { SetLastWrites(resource, last_writes); } } void SideEffectAnalysisInfo::Exit() { - for (auto [resource, _] : per_resource_access_info_) { + for (const auto& [resource, _] : per_resource_access_info_) { absl::flat_hash_set last_writes = GetLastWrites(resource); auto& resource_to_operations = stack_up_.back(); resource_to_operations.try_emplace(resource); @@ -265,7 +265,7 @@ void SideEffectAnalysisInfo::Exit() { void SideEffectAnalysisInfo::Down() { stack_down_.emplace_back(); stack_up_.emplace_back(); - for (auto [resource, _] : per_resource_access_info_) { + for (const auto& [resource, _] : per_resource_access_info_) { absl::flat_hash_set last_writes = GetLastWrites(resource); stack_down_.back()[resource] = last_writes; } @@ -279,7 +279,7 @@ void SideEffectAnalysisInfo::Lateral() { 
void SideEffectAnalysisInfo::Up() { Exit(); - for (auto [resource, last_writes] : stack_up_.back()) { + for (const auto& [resource, last_writes] : stack_up_.back()) { SetLastWrites(resource, last_writes); } stack_down_.pop_back(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index a0549858ffd111..5d145c85a68a06 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -37,10 +38,12 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project @@ -64,9 +67,10 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// - // Allow all call operations to be inlined. + // Returns whether it's legal to inline a call to a function. 
bool isLegalToInline(Operation* call, Operation* callable, bool wouldBeCloned) const final { + if (isa(call)) return false; return true; } @@ -77,10 +81,8 @@ struct TFInlinerInterface : public DialectInlinerInterface { return true; } - // Defines the legality of inlining TF Device operations. - bool isLegalToInline(Operation*, Region*, bool, - IRMapping&) const final { - // For now, enable inlining all operations. + // Defines the legality of inlining TF Device operations into a region. + bool isLegalToInline(Operation* call, Region*, bool, IRMapping&) const final { return true; } @@ -130,6 +132,27 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) addInterfaces(); } +//===----------------------------------------------------------------------===// +// tf_device.cluster_func +//===----------------------------------------------------------------------===// + +LogicalResult ClusterFuncOp::verifySymbolUses( + mlir::SymbolTableCollection& symbolTable) { + StringAttr func_attr = getFuncAttr().getRootReference(); + func::FuncOp func = + symbolTable.lookupNearestSymbolFrom(*this, func_attr); + if (!func) { + return emitError("'func' attribute refers to an undefined function: ") + << func_attr.getValue(); + } + return success(); +} + +void ClusterFuncOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + SymbolRefAttr calleeAttr = callee.get(); + return setFuncAttr(cast(calleeAttr)); +} + //===----------------------------------------------------------------------===// // tf_device.launch //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 343127301d4057..74a12cfa1e22db 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -19,8 +19,10 @@ limitations under the License. 
#define TF_DEVICE_DIALECT include "mlir/IR/OpBase.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// // TensorFlow Device Dialect definitions @@ -335,7 +337,8 @@ used to form the cluster. let hasCanonicalizer = 1; } -def TfDevice_ClusterFuncOp : TfDevice_Op<"cluster_func", []> { +def TfDevice_ClusterFuncOp : TfDevice_Op<"cluster_func", + [CallOpInterface, DeclareOpInterfaceMethods]> { let summary = [{ The `tf_device.cluster_func` launches a function containing the body of a cluster. @@ -347,7 +350,7 @@ This op is used for outlining a cluster. let arguments = (ins FlatSymbolRefAttr:$func, - Variadic + Variadic:$args ); let results = (outs @@ -355,10 +358,19 @@ This op is used for outlining a cluster. ); let extraClassDeclaration = [{ + // Gets the argument operands to the called function. + operand_range getArgOperands() { return getArgs(); } + MutableOperandRange getArgOperandsMutable() { + return getArgsMutable(); + } // returns the function that this operation will launch. func::FuncOp getFuncOp() { return SymbolTable::lookupNearestSymbolFrom(*this, getFuncAttr()); } + CallInterfaceCallable getCallableForCallee() { + return getFuncAttr(); + } + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); }]; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 71cf2490e911f4..5e0e58c279e358 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -12463,6 +12463,18 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." 
let hasFolder = 1; } +def TF_ReadFileOp : TF_Op<"ReadFile", [Pure, TF_NoConstantFold]> { + let summary = "Reads and outputs the entire contents of the input filename."; + + let arguments = (ins + TF_StrTensor:$filename + ); + + let results = (outs + TF_StrTensor:$contents + ); +} + def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> { let summary = "Reads the value of a variable."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc index 879aa62ab28f03..b2ae51a1189686 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -90,7 +90,7 @@ mlir::LogicalResult PwStreamResultsOp::verify() { } //===----------------------------------------------------------------------===// -// IfrtProgramCall +// IfrtCall //===----------------------------------------------------------------------===// mlir::LogicalResult IfrtCallOp::verify() { @@ -115,6 +115,26 @@ mlir::LogicalResult IfrtCallOp::verify() { } } + // Verify variable_arg_indices is sorted in ascending order. 
+ int64_t prev_index = -1; + for (auto arg_index_attr : getVariableArgIndicesAttr()) { + if (!arg_index_attr.isa_and_nonnull()) { + return emitOpError() << "variable_arg_indices must be an integer"; + } + + int64_t index = + arg_index_attr.dyn_cast().getValue().getSExtValue(); + if (index < 0) { + return emitOpError() << "variable_arg_indices must be positive"; + } + + if (index <= prev_index) { + return emitOpError() + << "variable_arg_indices must be sorted in ascending order"; + } + prev_index = index; + } + return mlir::success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td index af4bdcea69182e..fe230904c241be 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.td @@ -63,7 +63,7 @@ def TF__TfrtGetResourceOp : TF_Op<"_TfrtGetResource", let hasVerifier = 1; } -def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { +def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", [Pure]> { let summary = "Loads a variable tensor as an IFRT array"; let description = [{ @@ -77,6 +77,9 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding configuration specified in `VariableDeviceShardingConfigProto`. + + This op returns a scalar string tensor containing the loaded variable name, which can be + used as a key to look for the loaded IFRT array in runtime. 
}]; let arguments = (ins @@ -85,7 +88,12 @@ def TF_IfrtLoadVariableOp : TF_Op<"IfrtLoadVariable", []> { DefaultValuedAttr:$name ); + let results = (outs + TF_StrTensor:$array_key + ); + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; } @@ -101,14 +109,13 @@ def TF_IfrtCallOp : TF_Op<"IfrtCall", []> { that the outlined function is compiled into an executable and is available for lookup from `IfrtCall` TF ops. - This op also takes `variable_names` attribute to bind the variables (weights) - by names. + `variable_arg_indices` is a sorted (ascending order) array and indicates which + element of `args` is a key to a loaded array corresponding to a variable. }]; let arguments = (ins Variadic : $args, I64Attr : $program_id, - StrArrayAttr : $variable_names, I32ArrayAttr : $variable_arg_indices ); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index be0ab858484acd..05e5105883f94b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -752,3 +752,14 @@ func.func @testGlobalIterIdNotFolded() -> (tensor) { // CHECK: return %[[X]] func.return %0: tensor } + +// ----- + +// CHECK-LABEL: func @testReadFileOpNotFolded +func.func @testReadFileOpNotFolded() -> (tensor) { + %0 = "tf.Const"() { value = dense<"filepath"> : tensor } : () -> tensor + // CHECK: %[[X:.*]] = "tf.ReadFile" + %1 = "tf.ReadFile"(%0) : (tensor) -> tensor + // CHECK: return %[[X]] + func.return %1: tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir index d473c1a7d7b67e..5162a2df7366e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tf-executor-convert-control-to-data-outputs -split-input-file %s | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(tf-executor-convert-control-to-data-outputs{composite-tpuexecute-side-effects})' -split-input-file -verify-diagnostics | FileCheck %s !tf_res = tensor>> @@ -574,3 +574,227 @@ func.func @unconnected(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) { } func.return } + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_while_body +func.func @tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island({{.*}}) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg0, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1]} : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK-DAG: [[exe]]{{.*}}"tf.Identity"(%arg3) + // CHECK-DAG: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control 
+ } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_while_cond +func.func @tpu_execute_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute +func.func @tpu_execute(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_while_body, + cond = @tpu_execute_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @incomplete_composite_devices_while_body +func.func @incomplete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island({{.*}}) wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1]} : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, 
%arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: [[exe]]{{.*}}"tf.Identity"(%arg3) + // CHECK-NOT: "tf.Identity" + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @incomplete_composite_devices_while_cond +func.func @incomplete_composite_devices_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @incomplete_composite_devices +func.func @incomplete_composite_devices(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @incomplete_composite_devices_while_body, + cond = @incomplete_composite_devices_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @complete_composite_devices_while_body +func.func @complete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables" + 
%exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1]} : (!tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: "tf.Identity"(%arg3) + // CHECK: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @complete_composite_devices_while_cond +func.func @complete_composite_devices_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @complete_composite_devices +func.func @complete_composite_devices( + %arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @complete_composite_devices_while_body, + cond = 
@complete_composite_devices_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands_while_body +func.func @tpu_execute_with_non_resource_operands_while_body(%arg0: !tf_res, %arg1: !tf_res, + %arg2: tensor) + -> (!tf_res, !tf_res, tensor) { + %graph:3 = tf_executor.graph { + %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str + // CHECK: [[exe:%.*]] = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables" + %exe_control = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { + device_var_reads_indices = [1, 2], + device_var_updates_indices = [1, 2]} : (tensor, !tf_res, !tf_res, !tf_str) -> () + + %assign_control_0 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor>>, tensor) -> () + %assign_control_1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor>>, tensor) -> () + %add, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg2) : (tensor, tensor) -> tensor + %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor + %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () + // CHECK: "tf.Identity"(%arg3) + // CHECK: "tf.Identity"(%arg4) + // CHECK: tf_executor.fetch + tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control + } + func.return %graph#0, %graph#1, %graph#2 : !tf_res, !tf_res, tensor +} + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands_while_cond +func.func @tpu_execute_with_non_resource_operands_while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (tensor) { + %graph = 
tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"(%arg2) : (tensor) -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @tpu_execute_with_non_resource_operands +func.func @tpu_execute_with_non_resource_operands(%arg0: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, + %arg1: !tf_res {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}, %arg2: tensor) { + tf_executor.graph { + // CHECK: "tf.Const"{{.*}}tensor + %while_out:3, %control_while = tf_executor.island wraps "tf.While"(%arg0, %arg1, %arg2) + {body = @tpu_execute_with_non_resource_operands_while_body, + cond = @tpu_execute_with_non_resource_operands_while_cond, is_stateless = false} + : (tensor>>, tensor>>, tensor) + -> (tensor>>, tensor>>, tensor) + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir index 9c0eeaaf33e782..9c3202f2dc4945 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -495,9 +495,9 @@ func.func @decompose_resource_gather_op(%indices : tensor) -> tensor<*xi3 %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource>> // CHECK-DAG: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]]) - // CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) <{batch_dims = 0 : i64}> {_xla_outside_compilation = "0"} : (tensor<*xi32>, tensor, tensor) -> tensor<*xi32> + // CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) <{batch_dims = 0 : i64}> : (tensor<*xi32>, tensor, tensor) -> tensor<*xi32> // CHECK: return [[GATHER]] - %1 = "tf.ResourceGather"(%resource, %indices) 
{_xla_outside_compilation = "0"} : (tensor<*x!tf_type.resource>>, tensor) -> (tensor<*xi32>) + %1 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf_type.resource>>, tensor) -> (tensor<*xi32>) tf_device.return %1 : tensor<*xi32> }) : () -> (tensor<*xi32>) func.return %0: tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir index 90ffba031e12de..f4a9e62c3b80c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/end-to-end-tpu-reshard-variables.mlir @@ -12,55 +12,51 @@ // CHECK: TPUExecuteAndUpdateVariables // CHECK: tf_device.replicate // CHECK: TPUReshardVariables -"builtin.module"() ({ - "func.func"() ({ - ^bb0(%res : tensor<*x!tf_type.resource>): - "tf_executor.graph"() ({ - %ctrl = "tf_executor.island"() ({ - "tf.StatefulPartitionedCall"(%res) {config = "", config_proto = "", executor_type = "", f = @partitioned} : (tensor<*x!tf_type.resource>) -> () - "tf_executor.yield"() : () -> () - }) : () -> (!tf_executor.control) - "tf_executor.fetch"(%ctrl) : (!tf_executor.control) -> () - }) : () -> () - "func.return"() : () -> () - }) {arg_attrs = [{tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}], function_type = (tensor<*x!tf_type.resource>) -> (), sym_name = "main"} : () -> () - "func.func"() ({ - ^bb0(%res : tensor<*x!tf_type.resource>): - "tf_executor.graph"() ({ - %ctrl = "tf_executor.island"() ({ - %w = "tf.While"(%res) {body = @while_body, cond = @while_cond, is_stateless = false, shape_invariant} : (tensor<*x!tf_type.resource>) -> (tensor>>) - "tf_executor.yield"() : () -> () - }) : () -> (!tf_executor.control) - "tf_executor.fetch"(%ctrl) : (!tf_executor.control) -> () - }) : () -> () - "func.return"() : () -> () - }) 
{arg_attrs = [{tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}], function_type = (tensor<*x!tf_type.resource>) -> (), sym_name = "partitioned", sym_visibility = "private"} : () -> () - "func.func"() ({ - ^bb0(%res : tensor>>): - %g = "tf_executor.graph"() ({ - %i:2 = "tf_executor.island"() ({ - "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = true, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 2 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> () - %one = "tf.Const"() {_tpu_replicate = "cluster", value = dense<1> : tensor} : () -> tensor<*xi32> - %res_rep = "tf.TPUReplicatedInput"(%res) {index = -1 : i64, is_mirrored_variable = true, is_packed = true} : (tensor>>) -> tensor>> - %read = "tf.ReadVariableOp"(%res_rep) {_tpu_replicate = "cluster", device = ""} : (tensor>>) -> tensor<*xi32> - %inc = "tf.Add"(%read, %one) {_tpu_replicate = "cluster"} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - "tf.AssignVariableOp"(%res_rep, %inc) {_tpu_replicate = "cluster"} : (tensor>>, tensor<*xi32>) -> () - %res_out:2 = "tf.TPUReplicatedOutput"(%res_rep) : (tensor>>) -> (tensor>>, tensor>>) - "tf_executor.yield"(%res) : (tensor>>) -> () - }) : () -> (tensor>>, !tf_executor.control) - "tf_executor.fetch"(%i#0) : (tensor>>) -> () - }) : () -> tensor>> - "func.return"(%g) : (tensor>>) -> () - }) {arg_attrs = [{tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}], function_type = (tensor>>) -> tensor>>, sym_name = "while_body", sym_visibility = "private"} : () -> () - "func.func"() ({ - ^bb0(%res : tensor>>): - %g = 
"tf_executor.graph"() ({ - %i:2 = "tf_executor.island"() ({ - %c = "tf.Const"() {value = dense<0> : tensor} : () -> tensor<*xi1> - "tf_executor.yield"(%c) : (tensor<*xi1>) -> () - }) : () -> (tensor<*xi1>, !tf_executor.control) - "tf_executor.fetch"(%i#0) : (tensor<*xi1>) -> () - }) : () -> tensor<*xi1> - "func.return"(%g) : (tensor<*xi1>) -> () - }) {arg_attrs = [{tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}], function_type = (tensor>>) -> (tensor<*xi1>), sym_name = "while_cond", sym_visibility = "private"} : () -> () -}) {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU:1" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1163 : i32}} : () -> () +module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU:0" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU:1" = {}, "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1163 : i32}} { + func.func @main(%arg0: tensor<*x!tf_type.resource> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) { + tf_executor.graph { + %control = tf_executor.island { + "tf.StatefulPartitionedCall"(%arg0) <{config = "", config_proto = "", executor_type = "", f = @partitioned}> : (tensor<*x!tf_type.resource>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return + } + 
func.func private @partitioned(%arg0: tensor<*x!tf_type.resource> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) { + tf_executor.graph { + %control = tf_executor.island { + %0 = "tf.While"(%arg0) <{body = @while_body, cond = @while_cond, is_stateless = false, shape_invariant}> : (tensor<*x!tf_type.resource>) -> tensor>> + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return + } + func.func private @while_body(%arg0: tensor>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> tensor>> { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island { + "tf.TPUReplicateMetadata"() <{allow_soft_placement = true, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 2 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true}> {_tpu_replicate = "cluster", device = ""} : () -> () + %cst = "tf.Const"() <{value = dense<1> : tensor}> {_tpu_replicate = "cluster"} : () -> tensor<*xi32> + %1 = "tf.TPUReplicatedInput"(%arg0) <{index = -1 : i64, is_mirrored_variable = true, is_packed = true}> : (tensor>>) -> tensor>> + %2 = "tf.ReadVariableOp"(%1) {_tpu_replicate = "cluster", device = ""} : (tensor>>) -> tensor<*xi32> + %3 = "tf.Add"(%2, %cst) {_tpu_replicate = "cluster"} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + "tf.AssignVariableOp"(%1, %3) {_tpu_replicate = "cluster"} : (tensor>>, tensor<*xi32>) -> () + %4:2 = "tf.TPUReplicatedOutput"(%1) : (tensor>>) -> (tensor>>, tensor>>) + tf_executor.yield %arg0 : tensor>> + } + tf_executor.fetch %outputs : tensor>> + } + return %0 : tensor>> + } + func.func private @while_cond(%arg0: tensor>> 
{tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> tensor<*xi1> { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island { + %cst = "tf.Const"() <{value = dense : tensor}> : () -> tensor<*xi1> + tf_executor.yield %cst : tensor<*xi1> + } + tf_executor.fetch %outputs : tensor<*xi1> + } + return %0 : tensor<*xi1> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir new file mode 100644 index 00000000000000..c88310443ed0d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/hoist_broadcast_read.mlir @@ -0,0 +1,68 @@ +// RUN: tf-opt %s -split-input-file -tf-hoist-broadcast-read | FileCheck %s + +// The read should be hoisted. + +// CHECK-LABEL: func @hoist_cpu +func.func @hoist_cpu(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp" + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: "tf.OpA"(%[[READ]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because the resource does not have device type CPU. + +// CHECK-LABEL: func @only_hoist_cpu +func.func @only_hoist_cpu(%arg0: tensor<*x!tf_type.resource>>) -> () { + // CHECK: tf_device.replicate + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + func.return +} + +// ----- + +// The read should not be hoisted because it follows a write. 
+ +// CHECK-LABEL: func @skip_read_after_write +func.func @skip_read_after_write(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: tf_device.replicate + // CHECK: "tf.AssignVariableOp" + // CHECK-NEXT: "tf.ReadVariableOp" + tf_device.replicate {n = 2 : i32} { + %0 = "tf.OpA"() : () -> tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor<*x!tf_type.resource>>, tensor) -> () + %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpB"(%1) : (tensor) -> () + } + func.return +} + +// ----- + +// Check that hoisting preserves read order. + +// CHECK-LABEL: func @order_preserved +func.func @order_preserved(%arg0: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>>, %arg2: tensor<*x!tf_type.resource>> {tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> () { + // CHECK: %[[READ0:.*]] = "tf.ReadVariableOp"(%arg0) + // CHECK-NEXT: %[[READ2:.*]] = "tf.ReadVariableOp"(%arg2) + // CHECK-NEXT: tf_device.replicate + // CHECK-NEXT: %[[READ1:.*]] = "tf.ReadVariableOp"(%arg1) + // CHECK-NEXT: "tf.OpA"(%[[READ0]], %[[READ1]], %[[READ2]]) + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + %1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor + %2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0, %1, %2) : (tensor, tensor, tensor) -> () + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir index 021cad3b78be8f..389a682d3afe46 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir @@ 
-167,9 +167,25 @@ func.func @UnsupportedOp(%arg0: tensor) -> tensor { // _XlaHostComputeMlir with manual_sharding should not fall back to // XlaHostCompute, because XlaHostCompute does not support manual_sharding. +// Instead, it is skipped and the MlirXlaOpKernel is expected to handle it. func.func @HostComputeManualNoFallback(%arg0: tensor) -> () { - // expected-error @+1 {{manual_sharding not supported with fallback}} + // CHECK: "tf._XlaHostComputeMlir" %1 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv1", send_key = "host_compute_channel_send1", host_mlir_module = "", manual_sharding = true} : (tensor) -> (tensor) func.return } + +// ----- + +// CHECK-LABEL: test_xla_call_module_with_host_communicative_subcomputation +func.func @test_xla_call_module_with_host_communicative_subcomputation() { + "tf.XlaCallModule"() {Sout = [], device = "", dim_args_spec = [], function_list = [@callee], module = "", platforms = [], version = 4 : i64} : () -> () + func.return +} + +// CHECK-LABEL: callee +func.func private @callee(%arg0: tensor) { + "tf.XlaHostCompute"(%arg0) <{ancestors = [], key = "@host_func", recv_key = "", send_key = "", shapes = []}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor) -> () + return + } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index eff3e38ab5ace2..479df14b6546f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -960,3 +960,59 @@ func.func @testGeneratorDatasetRegionWithComplexBlocks(%arg0: tensor<4xf32>, %ar }) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", metadata = "", operandSegmentSizes = array, output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string]} : 
(tensor<4xf32>, tensor<3xf32>, tensor, tensor<2xf32>) -> tensor return } + +// ----- + +func.func private @tf.WhileRegion_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %1 = builtin.unrealized_conversion_cast to tensor + %2 = builtin.unrealized_conversion_cast to tensor + return %1, %2 : tensor, tensor +} +func.func private @tf.WhileRegion_cond(%arg0: tensor) -> tensor { + %0 = builtin.unrealized_conversion_cast to tensor + return %0 : tensor +} +// CHECK-LABEL: testNameCollision +func.func @testNameCollision(%arg0: tensor) { + %1 = builtin.unrealized_conversion_cast to tensor + %2 = builtin.unrealized_conversion_cast to tensor + // CHECK: "tf.While" + // CHECK-SAME: body = @tf.WhileRegion_body_1 + // CHECK-SAME: cond = @tf.WhileRegion_cond_0 + %3:2 = "tf.WhileRegion"(%1, %2) <{is_stateless = false}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = func.call @tf.WhileRegion_cond(%arg1) : (tensor) -> tensor + "tf.Yield"(%8, %arg1, %arg2) : (tensor, tensor, tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8:2 = func.call @tf.WhileRegion_body(%arg1, %arg2) : (tensor, tensor) -> (tensor, tensor) + "tf.Yield"(%8#0, %8#1) : (tensor, tensor) -> () + }) : (tensor, tensor) -> (tensor, tensor) + return +} + +// ----- + +func.func private @my_cond(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = builtin.unrealized_conversion_cast to tensor + return %0 : tensor +} +func.func private @my_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + return %arg0, %arg1 : tensor, tensor +} +// CHECK-LABEL: testConditionWithPassthroughArgs +func.func @testConditionWithPassthroughArgs(%arg1: tensor, %arg2: tensor) { + // CHECK: "tf.While" + // CHECK-SAME: body = @my_body + // CHECK-SAME: cond = @my_cond + %3:2 = "tf.WhileRegion"(%arg1, %arg2) <{is_stateless = false}> ({ + ^bb0(%barg1: tensor, %barg2: tensor): + %8 = func.call @my_cond(%barg1, %barg2) : (tensor, tensor) -> tensor + "tf.Yield"(%8, %barg1, %barg2) : (tensor, tensor, tensor) -> () + }, { + 
^bb0(%barg1: tensor, %barg2: tensor): + %r1, %r2 = func.call @my_body(%barg1, %barg2) : (tensor, tensor) -> (tensor, tensor) + "tf.Yield"(%r1, %r2) : (tensor, tensor) -> () + }) : (tensor, tensor) -> (tensor, tensor) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir index e4c1941a1c202f..c3608a2fb13145 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir @@ -175,11 +175,13 @@ func.func @can_remove_all_results(%arg0: f32) -> (f32, f32) { // CHECK-LABEL: @has_inner_function func.func private @has_inner_function(%arg0: f32) -> (f32, f32) { - func.func private @inner() -> (tensor, tensor) { - %0, %1 = "some_constant"() : () -> (tensor, tensor) - // CHECK: return - // CHECK-SAME: tensor, tensor - return %0, %1 : tensor, tensor + builtin.module { + func.func private @inner() -> (tensor, tensor) { + %0, %1 = "some_constant"() : () -> (tensor, tensor) + // CHECK: return + // CHECK-SAME: tensor, tensor + return %0, %1 : tensor, tensor + } } // CHECK: return // CHECK-NOT: arg diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 034592fe138580..f7bac4ba31b50a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1436,6 +1436,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return %0 : tensor<*xi32> } + // CHECK-LABEL: func @xla_gather_with_fold + // CHECK-SAME: (%arg0: tensor<1x1x9xi32>, %arg1: tensor<1xi32>) -> tensor<1x1x8xi32> + func.func @xla_gather_with_fold(%arg0: tensor<1x1x9xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { + %cst = "tf.Const"() {value = dense<[1, 1, 8]> : tensor<3xi32>} : () -> tensor<3xi32> + 
%slice_size = "tf.Identity"(%cst) : (tensor<3xi32>) -> tensor<3xi32> + %0 = "tf.XlaGather"(%arg0, %arg1, %slice_size) {dimension_numbers = "\0A\03\00\01\02\1A\01\02", indices_are_sorted = true} : (tensor<1x1x9xi32>, tensor<1xi32>, tensor<3xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> + } + // CHECK: func private @sum_reducer3(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-NEXT: %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor // CHECK-NEXT: return %0 : tensor @@ -2256,4 +2265,4 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %3#1, %3#2, %4, %5 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> } -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir index fc6ef646e340dd..4632eea672bb78 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s | tf-opt | FileCheck %s +// RUN: tf-opt --verify-diagnostics --split-input-file %s | FileCheck %s // CHECK-LABEL: func @return_no_operands func.func @return_no_operands() { @@ -9,6 +9,8 @@ func.func @return_no_operands() { func.return } +// ----- + // CHECK-LABEL: func @return_one_operand // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>) func.func @return_one_operand(%arg_0: tensor<*xf32>) { @@ -19,6 +21,8 @@ func.func @return_one_operand(%arg_0: tensor<*xf32>) { func.return } +// ----- + // CHECK-LABEL: func @return_multiple_operands // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<*xi32>) func.func @return_multiple_operands(%arg_0: tensor<*xf32>, %arg_1: tensor<*xi32>) { @@ -29,6 +33,8 @@ func.func @return_multiple_operands(%arg_0: tensor<*xf32>, %arg_1: tensor<*xi32> func.return } +// ----- + // CHECK-LABEL: func @empty_replicate func.func 
@empty_replicate() { tf_device.replicate {n = 2 : i32} { @@ -40,6 +46,8 @@ func.func @empty_replicate() { // CHECK-NEXT: tf_device.return } +// ----- + // CHECK-LABEL: func @no_operand_replicate func.func @no_operand_replicate() { tf_device.replicate {n = 2 : i32} { @@ -53,6 +61,8 @@ func.func @no_operand_replicate() { // CHECK: tf_device.return } +// ----- + // CHECK-LABEL: func @replicate_with_multiple_operands func.func @replicate_with_multiple_operands() { %0 = "tf.opA"() : () -> tensor<*xi1> @@ -95,6 +105,8 @@ func.func @replicate_with_multiple_operands() { // CHECK-NEXT: tf_device.return } +// ----- + // CHECK-LABEL: func @replicate_derived_operandSegmentSizes func.func @replicate_derived_operandSegmentSizes() { tf_device.replicate {n = 2 : i32, operandSegmentSizes = array} { @@ -107,6 +119,8 @@ func.func @replicate_derived_operandSegmentSizes() { // CHECK-NEXT: tf_device.return } +// ----- + // CHECK-LABEL: func @replicate_with_return // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>) func.func @replicate_with_return(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xi32>) { @@ -121,6 +135,8 @@ func.func @replicate_with_return(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %ar // CHECK-NEXT: tf_device.return %[[INPUT_0]], %[[ARG_2]] } +// ----- + // CHECK-LABEL: func @replicate_with_devices func.func @replicate_with_devices() { tf_device.replicate() {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}} { @@ -134,6 +150,8 @@ func.func @replicate_with_devices() { // CHECK-NEXT: tf_device.return } +// ----- + // CHECK-LABEL: func @replicate_with_multiple_devices func.func @replicate_with_multiple_devices() { tf_device.replicate() {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2", "/DEVICE:3"]}} { @@ -147,6 +165,8 @@ func.func @replicate_with_multiple_devices() { // CHECK-NEXT: tf_device.return 
} +// ----- + // CHECK-LABEL: func @replicate_with_inner_ops func.func @replicate_with_inner_ops() { %0 = "tf.opA"() : () -> (tensor<*xi1>) @@ -162,6 +182,8 @@ func.func @replicate_with_inner_ops() { func.return } +// ----- + // CHECK-LABEL: func @parallel_execute_two_regions func.func @parallel_execute_two_regions() { "tf_device.parallel_execute"() ({ @@ -173,6 +195,8 @@ func.func @parallel_execute_two_regions() { func.return } +// ----- + // CHECK-LABEL: func @parallel_execute_two_regions_with_ops func.func @parallel_execute_two_regions_with_ops() { "tf_device.parallel_execute"() ({ @@ -187,6 +211,8 @@ func.func @parallel_execute_two_regions_with_ops() { func.return } +// ----- + // CHECK-LABEL: func @parallel_execute_regions_with_data_results func.func @parallel_execute_regions_with_data_results() { "tf_device.parallel_execute"() ({ @@ -200,3 +226,11 @@ func.func @parallel_execute_regions_with_data_results() { }) {} : () -> (tensor<*xi1>, tensor<*xi32>, tensor<*xf32>) func.return } + +// ----- + +func.func @parallel_execute_regions_with_data_results(%arg0: tensor) -> tensor { + // expected-error @+1 {{'func' attribute refers to an undefined function: undefined_func}} + %0 = "tf_device.cluster_func"(%arg0) {func = @undefined_func} : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 9c0f9b2eddb29b..3e7f029d8ff864 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -161,3 +161,70 @@ module attributes {tf_saved_model.semantics} { // tf_saved_model.semantics. 
// CHECK-LABEL: module module {} + +// ----- + +!tf_res = tensor>> +!tf_str = tensor<3x!tf_type.string> + +// CHECK-LABEL: module attributes +module attributes {tf_saved_model.semantics} { + +"tf_saved_model.global_tensor"() {sym_name = "v1", type = tensor, value = dense<3.0> : tensor } : () -> () +"tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor, value = dense<2.0> : tensor } : () -> () + +// CHECK-LABEL: @body +func.func private @body(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res) { + %graph:2 = tf_executor.graph { + %value, %value_control = tf_executor.island wraps "tf.GetKey"() : () -> tensor + %ret0, %ret0_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %ret1, %ret1_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %control_unknown = tf_executor.island wraps "tf.UnknownOp"() : () -> () + %key, %key_control = tf_executor.island wraps "tf.GetKey"() : () -> !tf_str + // CHECK: "tf.ReadVariableOp"(%arg0) + %read1, %read1_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg1) + %read2, %read2_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor + tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%ret0, %ret1, %key) { + device_var_reads_indices = [0, 1], + device_var_updates_indices = [0, 1]} : (!tf_res, !tf_res, !tf_str) -> () + tf_executor.fetch %ret0, %ret1: !tf_res, !tf_res + } + func.return %graph#0, %graph#1 : !tf_res, !tf_res +} + +// CHECK-LABEL: @cond +func.func private @cond(%arg0: !tf_res, %arg1: !tf_res) -> (tensor) { + %graph = tf_executor.graph { + %island, %ctrl = tf_executor.island { + %pred = "tf.SomeOp"() : () -> tensor + tf_executor.yield %pred : tensor + } + tf_executor.fetch %island : tensor + } + func.return %graph : tensor +} + +// CHECK-LABEL: @test_while_loop +func.func @test_while_loop(%arg0: !tf_res {tf._composite_device = 
"/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf_saved_model.bound_input = @v1}, + %arg1: !tf_res {tf_saved_model.bound_input = @v2}) + attributes {tf_saved_model.exported_names = ["test_while_loop"]} { + // CHECK-DAG: Const{{.*}}2.0 + // CHECK-DAG: Const{{.*}}3.0 + %read1 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor + %read2 = "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor + // CHECK: tf_executor.graph + tf_executor.graph { + %handle0, %handle0_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %handle1, %handle1_control = tf_executor.island wraps "tf.SomeOp"() : () -> !tf_res + %control_A = tf_executor.island wraps "tf.OpA"() : () -> () + %while_out:2, %while_control = tf_executor.island(%control_A) wraps "tf.While"( + %handle0, %handle1) { + body = @body, cond = @cond, is_stateless = false + } : (tensor>>, tensor>>) -> (tensor>>, tensor>>) + %control_B = tf_executor.island(%while_control) wraps "tf.OpB"() : () -> () + tf_executor.fetch + } + func.return +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir index 3fb11e56172276..3bd1677bc30baa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tfrt_ops.mlir @@ -13,3 +13,15 @@ func.func @testPwStreamResults(%arg0: tensor, %arg1: tensor) { } // ----- +// CHECK-LABEL: func @test_ifrt_call +func.func @test_ifrt_call(%arg0: tensor, %arg1: tensor) { + %result = "tf.IfrtCall"(%arg0, %arg1) <{program_id = 1234 : i64, variable_arg_indices = [0 : i32, 1 : i32], variable_names = ["a", "b"]}> : (tensor, tensor) -> (tensor<1x1xf32>) + func.return +} + +// ----- +func.func @test_ifrt_call_fail_unsorted_variable_arg_indices(%arg0: tensor, %arg1: tensor) { + // expected-error@below {{variable_arg_indices must be sorted in ascending order}} + %result = "tf.IfrtCall"(%arg0, %arg1) <{program_id = 1234 : i64, variable_arg_indices = [1 : i32, 0 : i32], 
variable_names = ["a", "b"]}> : (tensor, tensor) -> (tensor<1x1xf32>) + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir index 91e4ff2b714cd2..4d400d437ab28a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir @@ -148,6 +148,8 @@ func.func @resource_missing_subtype(%arg0: tensor, %arg1: ten // ----- +func.func private @computation(%arg0: tensor) -> tensor + func.func @missing_num_cores_per_replica(%arg0: tensor>>) { // expected-error@+1 {{op num cores per replica unavailable}} %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> @@ -159,6 +161,8 @@ func.func @missing_num_cores_per_replica(%arg0: tensor) -> tensor + func.func @mismatch_num_cores_per_replica(%arg0: tensor>>) { // expected-error@+1 {{expects 2 operands but found 3}} %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg0, %arg0) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>, tensor>>) -> tensor>> @@ -243,6 +247,8 @@ func.func @non_replicated_sharding(%arg0: tensor>> // ----- +func.func private @computation(%arg0: tensor) -> tensor + func.func @packed_replicated(%arg0: tensor>> {tf.device = "COMPOSITE"}) { // expected-error@+1 {{support}} %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg0) {_XlaSharding = "", partition_dims = [], is_packed = false} : (tensor>>, tensor>>) -> tensor>> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 9796913d50cc15..55b68e5de2fb5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1334,6 +1334,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor }) : () -> () 
func.return } + func.func @empty_func() { + func.return + } } // ----- @@ -2447,6 +2450,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster_func"() {_replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } + func.func @empty_func() { + func.return + } } // ----- @@ -2457,6 +2463,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } + func.func @empty_func() { + func.return + } } // ----- @@ -2467,6 +2476,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster_func"() {_xla_compile_device_type = "XPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } + func.func @empty_func() { + func.return + } } // ----- @@ -2477,6 +2489,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor "tf_device.cluster_func"() {_xla_compile_device_type = "XPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> () func.return } + func.func 
@empty_func() { + func.return + } } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 468e3495439799..66cae16a67424d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -174,6 +174,72 @@ func.func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf_type.resou // ----- +// Tests sharding propagation in while region body. +// CHECK-LABEL: func @check_sharding_for_read_variable_inside_while_body +func.func @check_sharding_for_read_variable_inside_while_body(%arg0 : tensor, %arg1: tensor<*x!tf_type.resource>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1:1 = "tf_device.cluster_func"(%arg0, %0) {func = @func_with_sharding_inside_while_body, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", num_cores_per_replica = 2 : i64, use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor, tensor<128x1024xf32>) -> (tensor<128x1024xf32>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["", "\0A\0B\0C"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0D\0E\0F"] + func.return +} + +// CHECK-LABEL: func @func_with_sharding_inside_while_body +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.sharding = ""}, %{{[a-z0-9]+}}: tensor<128x1024xf32> {mhlo.sharding = "\0A\0B\0C"}) +// CHECK-SAME: -> (tensor<128x1024xf32> {mhlo.sharding = "\0D\0E\0F"}) +func.func @func_with_sharding_inside_while_body(%arg0: tensor, %arg1: tensor<128x1024xf32>) -> (tensor<128x1024xf32>) { + %cst = "tf.Const"() <{value = dense<0> : tensor}> {device = ""} : () -> tensor + %0:2 = "tf.WhileRegion"(%cst, %arg1) <{is_stateless = false, parallel_iterations = 1 : i64}> ({ + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Less"(%arg2, %arg0) : (tensor, tensor) -> 
tensor + "tf.Yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.XlaSharding"(%arg3) <{_XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %2 = "tf.Square"(%1) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + "tf.Yield"(%arg2, %2) : (tensor, tensor<128x1024xf32>) -> () + }) {_num_original_outputs = 1 : i64, _read_only_resource_inputs = [1], _xla_propagate_compile_time_consts = true} : (tensor, tensor<128x1024xf32>) -> (tensor, tensor<128x1024xf32>) + %1 = "tf.XlaSharding"(%0#1) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + func.return %1 : tensor<128x1024xf32> +} + +// ----- + +// Tests sharding propagation in while region condition. +// CHECK-LABEL: func @check_sharding_for_read_variable_inside_while_cond +func.func @check_sharding_for_read_variable_inside_while_cond(%arg0 : tensor, %arg1: tensor<*x!tf_type.resource>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %1:1 = "tf_device.cluster_func"(%arg0, %0) {func = @func_with_sharding_inside_while_cond, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", num_cores_per_replica = 2 : i64, use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor, tensor<128x1024xf32>) -> (tensor<128x1024xf32>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["", "\0A\0B\0C"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0D\0E\0F"] + func.return +} + +// CHECK-LABEL: func @func_with_sharding_inside_while_cond +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor {mhlo.sharding = ""}, %{{[a-z0-9]+}}: tensor<128x1024xf32> {mhlo.sharding = "\0A\0B\0C"}) +// CHECK-SAME: -> (tensor<128x1024xf32> {mhlo.sharding = "\0D\0E\0F"}) +func.func @func_with_sharding_inside_while_cond(%arg0: tensor, %arg1: tensor<128x1024xf32>) -> (tensor<128x1024xf32>) { + %cst = "tf.Const"() <{value = dense<0> 
: tensor}> {device = ""} : () -> tensor + %0:2 = "tf.WhileRegion"(%cst, %arg1) <{is_stateless = false, parallel_iterations = 1 : i64}> ({ + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.XlaSharding"(%arg3) <{_XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %2 = "tf.Less"(%arg2, %arg0) : (tensor, tensor) -> tensor + "tf.Yield"(%2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor, %arg3: tensor<128x1024xf32>): + %1 = "tf.Square"(%arg3) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + "tf.Yield"(%arg2, %1) : (tensor, tensor<128x1024xf32>) -> () + }) {_num_original_outputs = 1 : i64, _read_only_resource_inputs = [1], _xla_propagate_compile_time_consts = true} : (tensor, tensor<128x1024xf32>) -> (tensor, tensor<128x1024xf32>) + %1 = "tf.XlaSharding"(%0#1) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + func.return %1 : tensor<128x1024xf32> +} + +// ----- + // Tests with input sharding following an identity op and cast op. 
// CHECK-LABEL: func @check_sharding_after_cast_op func.func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { @@ -313,33 +379,6 @@ func.func @cluster_func(%arg0: tensor<*xf32>) { // ----- -// Tests that device variable sharding defaults to xla.OpSharding -// { type : MAXIMAL -// tile_assignment_dimensions: [ 1 ] -// tile_assignment_devices : [ 0 ] -// } - -// CHECK-LABEL: func @maximal_device_variable -func.func @maximal_device_variable(%arg0: tensor<*x!tf_type.resource>>) { - tf_device.replicate(%arg0 as %arg1: tensor<*x!tf_type.resource>>) - {_mirrored_variable_indices = [0], n = 2 : i32} { - %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<*xf32> - // CHECK: tf_device.cluster_func - // CHECK-SAME: input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"] - "tf_device.cluster_func"(%0) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64} : (tensor<*xf32>) -> () - tf_device.return - } - func.return -} - -// CHECK-LABEL: func @cluster_func -// CHECK-SAME: ({{.+}}: tensor<*xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -func.func @cluster_func(%arg0: tensor<*xf32>) { - func.return -} - -// ----- - // Tests that device variable sharding for an implicitly capture device variable // defaults to REPLICATE. 
@@ -564,6 +603,33 @@ func.func @func(%arg0: tensor<*xi32> {tf.aliasing_output = 1 : i64}, // ----- +// CHECK-LABEL: func @check_symmetric_alias_propagation +func.func @check_symmetric_alias_propagation(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["\01\02\03", "\04\05\06"] + // CHECK-SAME: output_sharding_configuration = ["\01\02\03", "\04\05\06"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = false, num_cores_per_replica = 1 : i64 + } : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) + func.return +} + +// CHECK-LABEL: func @func +// CHECK-SAME: %arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03" +// CHECK-SAME: %arg1: tensor<*xi32> {mhlo.sharding = "\04\05\06" +// CHECK-SAME: ->{{.*}}mhlo.sharding = "\01\02\03"{{.*}}mhlo.sharding = "\04\05\06" +func.func @func(%arg0: tensor<*xi32> {tf.aliasing_output = 0 : i64}, + %arg1: tensor<*xi32> {tf.aliasing_output = 1 : i64}) -> (tensor<*xi32>, tensor<*xi32>) { + %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>) + %2 = "tf.B"(%arg1) : (tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "\04\05\06"} : (tensor<*xi32>) -> tensor<*xi32> + func.return %2, %3 : tensor<*xi32>, tensor<*xi32> +} + +// ----- + // Partial tiled inputs using XlaSharding ops identified as REPLICATED should keep the sharding configuration. 
// The following xla.OpSharding is used: // Proto debug string: diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index bd660e8e8fcce5..66b81b0af9bb9a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -595,6 +595,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_ops_layout_helper", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_inc_gen", @@ -627,11 +628,13 @@ cc_library( "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:variant", @@ -781,6 +784,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/ir/types:Dialect", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index e16082fad89c4c..71ac89140d7a32 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -19,15 +19,25 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h" #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -48,12 +58,19 @@ using OperationSetTy = SmallPtrSet; using ResourceToOpsMapTy = DenseMap; #define GEN_PASS_DEF_EXECUTORCONVERTCONTROLTODATAOUTPUTSPASS +#define GEN_PASS_DECL_EXECUTORCONVERTCONTROLTODATAOUTPUTSPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" -class ConvertControlToDataOutputsPass +struct 
ConvertControlToDataOutputsPass : public impl::ExecutorConvertControlToDataOutputsPassBase< ConvertControlToDataOutputsPass> { - public: + ConvertControlToDataOutputsPass() = default; + explicit ConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects) + : ExecutorConvertControlToDataOutputsPassBase( + ExecutorConvertControlToDataOutputsPassOptions{ + composite_tpuexecute_side_effects}) {} + void runOnOperation() override; }; @@ -74,6 +91,41 @@ SmallVector GetWhileCallers(func::FuncOp func, return while_callers; } +bool IsResourceType(Type type) { + if (auto tensor_type = type.dyn_cast()) { + return tensor_type.getElementType().isa(); + } + return false; +} + +bool OnlyOperatesOnCompositeDevices(TF::TPUExecuteAndUpdateVariablesOp& op, + const DataFlowSolver& solver) { + llvm::SmallSet read_array; + for (const Attribute& attr : op.getDeviceVarReadsIndices()) { + read_array.insert(attr.cast().getInt()); + } + llvm::SmallSet update_array; + for (const Attribute& attr : op.getDeviceVarUpdatesIndices()) { + update_array.insert(attr.cast().getInt()); + } + for (auto& arg : op->getOpOperands()) { + if (!IsResourceType(arg.get().getType())) { + continue; + } + auto lattice = + solver.lookupState(arg.get()) + ->getValue(); + bool is_read = read_array.contains(arg.getOperandNumber()); + bool is_update = update_array.contains(arg.getOperandNumber()); + // We want the resource operands that are on composite devices to be the + // exact same set as the resource operands that are read or updated. + if ((is_read || is_update) != lattice.is_on_composite_device) { + return false; + } + } + return true; +} + // Populates `chain_resource_to_ops_map`, the map from all resources that need // to be chained to the set of operations that access the resource, and // `resource_equivalence_classes`. 
Resources are equivalent if they are accessed @@ -81,7 +133,8 @@ SmallVector GetWhileCallers(func::FuncOp func, void CollectChainResources( func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map, llvm::EquivalenceClasses& resource_equivalence_classes, - const TF::SideEffectAnalysis::Info& side_effect_analysis) { + const TF::SideEffectAnalysis::Info& side_effect_analysis, + const DataFlowSolver& solver, bool composite_tpuexecute_side_effects) { auto graph_op = cast(func.front().front()); // For each op in the graph, get the resources it uses and update the access @@ -93,6 +146,17 @@ void CollectChainResources( assert(island.WrapsSingleOp()); Operation& op = island.GetBody().front(); + // If the op only operates on resources stored on devices that are + // "COMPOSITE", then this op is defined to work in parallel with other + // TPUExecute* ops. So we don't need to track resources for it. + // TODO(b/325290168): Do this check per resource, not per op. + if (auto execute = llvm::dyn_cast(op)) { + if (composite_tpuexecute_side_effects && + OnlyOperatesOnCompositeDevices(execute, solver)) { + return WalkResult::advance(); + } + } + ResourceId prev_resource_id = kInvalidResourceId; for (auto resource_id_read_only_pair : side_effect_analysis.GetResourceIds(&op)) { @@ -113,6 +177,7 @@ void CollectChainResources( prev_resource_id = resource_id; } } + return WalkResult::advance(); }); } @@ -386,7 +451,8 @@ TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs, void ConvertControlToDataOutputs( func::FuncOp while_body, SmallVectorImpl& while_callers, OperationSetTy& recompute_analysis_for_funcs, - const TF::SideEffectAnalysis::Info& side_effect_analysis) { + const TF::SideEffectAnalysis::Info& side_effect_analysis, + const DataFlowSolver& solver, bool composite_tpuexecute_side_effects) { if (while_callers.empty()) return; // Collect access information for each resource in the while body that needs @@ -395,7 +461,8 @@ void 
ConvertControlToDataOutputs( ResourceToOpsMapTy chain_resource_to_ops_map; llvm::EquivalenceClasses resource_equivalence_classes; CollectChainResources(while_body, chain_resource_to_ops_map, - resource_equivalence_classes, side_effect_analysis); + resource_equivalence_classes, side_effect_analysis, + solver, composite_tpuexecute_side_effects); // Check for presence of unknown side-effecting ops within the while loop // body. These ops act as barriers and the optimization would not yield much @@ -459,6 +526,13 @@ void ConvertControlToDataOutputs( void ConvertControlToDataOutputsPass::runOnOperation() { ModuleOp module = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(module))) return signalPassFailure(); + // This pass assumes that all functions are suitable for export i.e., each // function has a single tf_executor.graph op and all islands wrap the // internal op perfectly. Verify that in the beginning once. 
@@ -500,7 +574,8 @@ void ConvertControlToDataOutputsPass::runOnOperation() { } ConvertControlToDataOutputs( while_body, while_callers, recompute_analysis_for_funcs, - side_effect_analysis.GetAnalysisForFunc(while_body)); + side_effect_analysis.GetAnalysisForFunc(while_body), solver, + composite_tpuexecute_side_effects_); } } @@ -511,5 +586,12 @@ CreateTFExecutorConvertControlToDataOutputsPass() { return std::make_unique(); } +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects) { + return std::make_unique( + composite_tpuexecute_side_effects); +} + } // namespace tf_executor } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index de4fabe59969c1..8a9e8451a91ba5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -55,9 +55,6 @@ def Clamp: NativeCodeCall< " $0.getLoc()," " $2.getType(), $2, $1, $3)">; -def CopyAttrs: NativeCodeCallVoid< - "CopyDeviceAndUnderscoredAttributesAdaptor($0, $1)">; - def DecomposeAssignAddVariableOp : Pat< (TF_AssignAddVariableOp:$src_op $resource, $value), @@ -341,12 +338,12 @@ def DecomposeResourceApplyAdamNesterov : def DecomposeResourceGather : Pat< (TF_ResourceGatherOp:$old_result $resource, $indices, $batch_dims, $validate_indices), - (TF_GatherV2Op:$dest + (TF_GatherV2Op (CreateTFReadVariableOp $old_result, $old_result, $resource), $indices, (TF_ConstOp $batch_dims), // axis $batch_dims - ), [], [(CopyAttrs $old_result, $dest)]>; + )>; // Pattern to decompose tf.ResourceScatterAdd into tf.ReadVariable, // tf.TensorScatterAdd, and tf.AssignVariable. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 125cbbd6163c33..0cff8946687dcb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -48,6 +48,11 @@ namespace { struct FunctionalControlFlowToRegions : public impl::FunctionalControlFlowToRegionsPassBase< FunctionalControlFlowToRegions> { + FunctionalControlFlowToRegions() = default; + explicit FunctionalControlFlowToRegions(bool allow_passthrough_args) + : FunctionalControlFlowToRegionsPassBase( + FunctionalControlFlowToRegionsPassOptions{allow_passthrough_args}) { + } void runOnOperation() override; }; @@ -251,6 +256,11 @@ std::unique_ptr> CreateTFFunctionalControlFlowToRegions() { return std::make_unique(); } +std::unique_ptr> CreateTFFunctionalControlFlowToRegions( + bool allow_passthrough_args) { + return std::make_unique( + allow_passthrough_args); +} } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index a7fa9268e152d0..1c0a125598cdbe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -330,6 +330,16 @@ class FuseMatMulBiasAdd }); return false; } + // FusedMatMul kernel does not support grad_a/grad_b attrs + if ((matmul->hasAttr("grad_a") && + matmul->getAttr("grad_a").cast().getValue()) || + (matmul->hasAttr("grad_b") && + matmul->getAttr("grad_b").cast().getValue())) { + (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) { + diag << "FusedMatMul kernel does not support grad_a/grad_b attrs"; + }); + return false; + } return true; } diff --git 
a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index a52962e5f7f326..9c475f1f9f5281 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -80,6 +80,8 @@ CreateTFFunctionalControlFlowToCFG(); // their region based counterparts. std::unique_ptr> CreateTFFunctionalControlFlowToRegions(); +std::unique_ptr> CreateTFFunctionalControlFlowToRegions( + bool allow_passthrough_args); // Transforms region bases control flow operations in the TensorFlow dialect to // their functional counterparts. @@ -339,6 +341,9 @@ namespace tf_executor { // Creates a pass to chain control outputs of while loop body. std::unique_ptr> CreateTFExecutorConvertControlToDataOutputsPass(); +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects); std::unique_ptr> CreateTFExecutorCheckControlDependenciesPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 56fcdd761999b1..661dafe2a2f327 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -47,8 +47,8 @@ bool IsCommunicationOp(Operation* op) { // subcomputation in the TF/XLA bridge. 
bool SupportsCommunicationComputation(Operation* op) { return isa(op); + TF::XlaCallModuleOp, TF::StatefulPartitionedCallOp, + TF::PartitionedCallOp, TF::LegacyCallOp>(op); } #define GEN_PASS_DEF_PREPARETPUCOMPUTATIONFORTFEXPORTPASS @@ -65,14 +65,17 @@ class RewriteXlaHostComputeMlir public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, - PatternRewriter& rewriter) const override { + LogicalResult match(TF::_XlaHostComputeMlirOp op) const override { if (op.getManualSharding()) { - op.emitOpError() << "manual_sharding not supported with fallback of " - "phase 2 legalize TF/XLA bridge. manual_sharding is " - "used by map_outside_compilation"; + // This rewrite does not support manual_sharding. It is expected that the + // _XlaHostComputeMlirOp registered as an MlirXlaOpKernel will handle this + // case later once the XlaBuilder graph reaches it. return failure(); } + return success(); + } + void rewrite(TF::_XlaHostComputeMlirOp op, + PatternRewriter& rewriter) const override { llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -132,7 +135,6 @@ class RewriteXlaHostComputeMlir op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); - return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 78f2c3e0423124..a669276e35a175 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -183,7 +183,7 @@ StringRef ExtractSingleBlockRegion( } ModuleOp module = region.getParentOfType(); - auto builder = OpBuilder::atBlockBegin(module.getBody()); + OpBuilder builder(module.getContext()); auto loc = 
region.getParentOp()->getLoc(); Block& entry = region.front(); int num_region_arguments = entry.getNumArguments(); @@ -236,8 +236,9 @@ StringRef ExtractSingleBlockRegion( outlined_func.setPrivate(); - // Uniquify the function name. - symbol_table.getSymbolTable(module).insert(outlined_func); + // Uniquify the function name, and insert into module. + symbol_table.getSymbolTable(module).insert(outlined_func, + module.getBody()->begin()); // Add the outlined function to the worklist in case its body has // IfRegion or WhileRegion ops that need to converted. @@ -246,9 +247,10 @@ StringRef ExtractSingleBlockRegion( } // Returns call for region with single call whose result feeds into the -// terminator of the region. if `allow_to_bool` is true, also allows a single -// ToBoolOp between the region yield and the call. Returns none if the region -// does not conform to this pattern. +// terminator of the region. If `allow_to_bool` is true, it allows patterns used +// in the condition of While ops, i.e. it allows a single bool (possibly passed +// through a ToBoolOp) between the region yield and the call. Returns none if +// the region does not conform to this pattern. std::optional IsSingleCallRegion(Region& region, bool allow_to_bool = false) { if (!llvm::hasSingleElement(region)) return std::nullopt; @@ -275,10 +277,23 @@ std::optional IsSingleCallRegion(Region& region, func::CallOp call = dyn_cast(*it++); if (!call) return std::nullopt; - // All call results should feed into expected consumer - // All results of the call should feed into the yield. - if (call.getNumResults() != call_consumer->getNumOperands()) - return std::nullopt; + if (allow_to_bool && call.getNumResults() == 1 && + yield->getNumOperands() != 1) { + // Allow patterns of the form + // %cond = call(...) + // yield %cond, [...passthrough args...] 
+ if (yield->getNumOperands() != block.getNumArguments() + 1) + return std::nullopt; + for (auto [yield_operand, block_arg] : + llvm::zip(yield->getOperands().drop_front(1), block.getArguments())) { + if (yield_operand != block_arg) return std::nullopt; + } + } else { + // All call results should feed into expected consumer + // All results of the call should feed into the yield. + if (call.getNumResults() != call_consumer->getNumOperands()) + return std::nullopt; + } for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands())) if (std::get<0>(res_it) != std::get<1>(res_it)) return std::nullopt; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index c18debb978fe0f..dc1cfe3f5920fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -28,7 +28,9 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" @@ -67,6 +69,7 @@ limitations under the License. #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -434,6 +437,24 @@ Type GetType(Attribute shape_attr, Attribute type_attr) { else return UnrankedTensorType::get(type.getValue()); } +} // namespace + +// Returns whether type can be further refined. 
+bool CanBeRefined(Type type) { + auto shape_type = type.dyn_cast(); + if (!shape_type) return false; + + // Returns whether type with subtypes can be further refined. + auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) { + return tws.GetSubtypes().empty() || + llvm::any_of(tws.GetSubtypes(), CanBeRefined); + }; + auto type_with_subtype = + shape_type.getElementType().dyn_cast(); + if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true; + + return !shape_type.hasStaticShape(); +} // Returns a new arg type based on the shape and element type. If there are // dynamic bounds attribute to the arg, update the bounds based on the shape @@ -464,25 +485,6 @@ Type GetNewArgType(Type old_arg_type, ArrayRef shape, return new_arg_type; } -} // namespace - -// Returns whether type can be further refined. -bool CanBeRefined(Type type) { - auto shape_type = type.dyn_cast(); - if (!shape_type) return false; - - // Returns whether type with subtypes can be further refined. 
- auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) { - return tws.GetSubtypes().empty() || - llvm::any_of(tws.GetSubtypes(), CanBeRefined); - }; - auto type_with_subtype = - shape_type.getElementType().dyn_cast(); - if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true; - - return !shape_type.hasStaticShape(); -} - // Combination of value producer and port of value produced (e.g., // :, // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output @@ -728,7 +730,8 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, class ShapeInference { public: ShapeInference(int64_t graph_version, ModuleOp module, - bool propagate_caller_callee_constants); + bool propagate_caller_callee_constants, + ArrayRef ops_to_skip); LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, ValuePortInputs* inputs) { @@ -901,8 +904,8 @@ class ShapeInference { bool RefineResultType(Operation* op, Value result, Type potential_refined_type); - // Infers the shape from a (Stateful)PartionedCall operation by looking up the - // called function and propagating the return type. + // Infers the shape from a (Stateful)PartitionedCall operation by looking up + // the called function and propagating the return type. bool InferShapeForCall(CallOpInterface call_op); bool InferShapeForCast(Operation* op); @@ -1006,6 +1009,9 @@ class ShapeInference { int64_t graph_version_; + // Op types for which shape inference should be skipped. + llvm::SmallDenseSet ops_to_skip_; + // TODO(b/154065712): Remove propagate_caller_callee_constants once using // SCCP pass instead. 
bool propagate_caller_callee_constants_; @@ -1020,11 +1026,15 @@ class ShapeInference { }; ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module, - bool propagate_caller_callee_constants) + bool propagate_caller_callee_constants, + ArrayRef ops_to_skip) : tf_dialect_(module->getContext()->getLoadedDialect()), symbol_users_(symbol_table_, module), graph_version_(graph_version), propagate_caller_callee_constants_(propagate_caller_callee_constants) { + for (const auto& op_type : ops_to_skip) { + ops_to_skip_.insert(op_type); + } // Create symbol table for module. symbol_table_.getSymbolTable(module); } @@ -1079,8 +1089,8 @@ bool ShapeInference::RefineResultType(Operation* op, Value result, result); } -// Infers the shape from a (Stateful)PartionedCall operation by looking up the -// called function and propagating the return type. +// Infers the shape from a (Stateful)PartitionedCall operation by looking up +// the called function and propagating the return type. bool ShapeInference::InferShapeForCall(CallOpInterface call_op) { func::FuncOp func = dyn_cast_or_null(call_op.resolveCallable(&symbol_table_)); @@ -1256,8 +1266,6 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { tsl::Status status = loader->RefineDynamicShapes(input_shapes); if (!status.ok()) { - llvm::errs() << "Failed during XlaCallModule shape refinement: " - << status.ToString(); // Do not return false here. // // RefineDynamicShapes returns ok only when it produces full static shapes. @@ -1266,6 +1274,7 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { // to abort here. // TODO(b/316639984): improve RefineDynamicShapes return values to include // these info. 
+ VLOG(1) << "Failed during XlaCallModule shape refinement: " << status; } mlir::ResultRange op_results = op.getResults(); // The main_outputs may include tokens that are not among the op_results; @@ -1972,8 +1981,17 @@ bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { return false; DenseIntElementsAttr slice_sizes_attr; - if (!matchPattern(op.getSliceSizes(), m_Constant(&slice_sizes_attr))) + if (DenseIntElementsAttr attr; + matchPattern(op.getSliceSizes(), m_Constant(&attr))) { + slice_sizes_attr = attr; + } else if (const auto it = results_.find(ValuePort(op.getSliceSizes())); + it != results_.end() && + llvm::isa_and_nonnull(it->second)) { + slice_sizes_attr = llvm::cast(it->second); + } else { return false; + } + llvm::SmallVector slice_sizes; for (const auto& attr : slice_sizes_attr.getValues()) { slice_sizes.push_back(attr.getSExtValue()); @@ -3135,6 +3153,14 @@ FailureOr ShapeInference::InferShapeUntilFixPoint( LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration << "\n"); auto res = region->walk([&](Operation* op) { + auto abstract_op = op->getRegisteredInfo(); + if (abstract_op && ops_to_skip_.contains(abstract_op->getTypeID())) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping shape inference for explicitly skipped op '" + << op->getName() << "'.\n"); + return WalkResult::advance(); + } + DCOMMENT_OP(op, "Inferring for"); if (auto infer_ti = dyn_cast(op)) { DCOMMENT("\tRefinining with type op interface"); @@ -3201,9 +3227,11 @@ static FailureOr InferShapeForFunction(ShapeInference& context, FailureOr InferShapeForFunction(func::FuncOp func, ArrayRef> arg_shapes, int64_t graph_version, - int64_t max_iterations) { + int64_t max_iterations, + ArrayRef ops_to_skip) { ShapeInference context(graph_version, func->getParentOfType(), - /*propagate_caller_callee_constants=*/true); + /*propagate_caller_callee_constants=*/true, + ops_to_skip); if (arg_shapes.empty()) { return InferShapeForFunction(context, func, max_iterations); } @@ -3255,7 
+3283,8 @@ FailureOr InferShapeForFunction(func::FuncOp func, return true; } -FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations) { +FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations, + ArrayRef ops_to_skip) { auto producer_or = tensorflow::GetTfGraphProducerVersion(module); if (!producer_or.ok()) { // TODO(jpienaar): Keeping the existing behavior for now but this could @@ -3268,7 +3297,8 @@ FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations) { // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if // it is no longer needed. ShapeInference context(producer, module, - /*propagate_caller_callee_constants=*/false); + /*propagate_caller_callee_constants=*/false, + ops_to_skip); if (auto main = module.lookupSymbol("main")) context.enqueue(main); for (auto func : module.getOps()) context.enqueue(func); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index e12c470d53f4cb..bc1cf7b3c8f475 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -20,9 +20,11 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project namespace mlir { namespace TF { @@ -30,13 +32,22 @@ namespace TF { // Returns whether type can be further refined. bool CanBeRefined(Type type); -// Refines all the shapes in a module. 
+// Returns a new arg type based on the shape and element type. If there are +// dynamic bounds attribute to the arg, update the bounds based on the shape +// as well. +Type GetNewArgType(Type old_arg_type, ArrayRef shape, + Type element_type, mlir::MLIRContext* context); + +// Refines all the shapes in a module, skipping the inference for all ops +// whose type is in ops_to_skip. // Returns a failure() on error, otherwise returns true to indicate that it // reached convergence, false otherwise. -FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations = 10); +FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations = 10, + ArrayRef ops_to_skip = {}); // Given a list of refined shapes matching the function arguments of func, runs -// shape inference over the function to propagate this updated information. +// shape inference over the function to propagate this updated information, +// skipping the inference for all ops whose type is in ops_to_skip. // If arg_shapes are empty, then argument shapes will be left unchanged. // Note: This affects the entire module, and changes are not just scoped to the // function being inferred. 
@@ -45,7 +56,8 @@ FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations = 10); FailureOr InferShapeForFunction(func::FuncOp func, ArrayRef> arg_shapes, int64_t graph_version, - int64_t max_iterations = 10); + int64_t max_iterations = 10, + ArrayRef ops_to_skip = {}); } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 8180e2d4084dba..37bcb46b95cc57 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -38,7 +38,7 @@ class ShapeInference public: void runOnOperation() override { auto failure_or_converged = - InferModuleShape(getOperation(), max_iterations_); + InferModuleShape(getOperation(), max_iterations_, /*ops_to_skip=*/{}); if (failed(failure_or_converged)) return signalPassFailure(); if (!failure_or_converged.value()) { getOperation().emitError() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index f50286123f2478..de47fbaca69635 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -516,6 +516,13 @@ def ExecutorConvertControlToDataOutputsPass : Pass<"tf-executor-convert-control- }]; let constructor = "tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()"; + + let options = [ + Option<"composite_tpuexecute_side_effects_", "composite-tpuexecute-side-effects", "bool", + /*default=*/"false", + "Enables certain TPUExecute ops to run in parallel if they only " + "operate on resources that live on composite devices."> + ]; } def ExecutorUpdateControlDependenciesPass : Pass<"tf-executor-update-control-dependencies", "ModuleOp"> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc 
b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 5366f8e0f92702..3edb71bf08e627 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -17,8 +17,13 @@ limitations under the License. #include #include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -31,6 +36,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -41,9 +47,11 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "xla/client/sharding_builder.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace mlir { namespace TFTPU { @@ -51,6 +59,8 @@ namespace { using OpShardingVariant = std::variant; using OpShardingVector = llvm::SmallVector; +using OptionalOpShardingVector = + llvm::SmallVector, 8>; constexpr char kReplicateSharding[] = ""; constexpr char kShardingAttr[] = "mhlo.sharding"; @@ -148,16 +158,6 @@ mlir::ArrayAttr GetStrArrayAttr(Builder* builder, return builder->getArrayAttr(strings); } -// Given a `tf_device.cluster_func` operand value return true iff it a device -// variable that should default to MAXIMAL sharding. Device variables that are -// per-replica or distributed default to MAXIMAL sharding, which corresponds to -// arguments of the `tf_device.replicate`. Otherwise the variable is broadcast, -// which corresponds to edges that are implicitly captured by the `replicate`. -bool IsMaximalVariable(Value value) { - auto read_var = value.getDefiningOp(); - return read_var && read_var->getParentOfType(); -} - // Verify whether the given sharding can be applied to the given (tensor) type. // (A bad sharding might mean failing tf.Split ops if the graph later executes // on CPU) @@ -248,7 +248,7 @@ std::optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( // Cast op may be added right after the input. // // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, -// Case, While) ops and Caller return values. +// Case) ops and Caller return values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. 
std::optional GetXlaShardingFromArg( @@ -270,6 +270,15 @@ std::optional GetXlaShardingFromArg( return logical_device; } + if (auto while_op = llvm::dyn_cast(owner)) { + const int operand_number = use.getOperandNumber(); + next_values_to_visit.push_back( + while_op.getCond().front().getArgument(operand_number)); + next_values_to_visit.push_back( + while_op.getBody().front().getArgument(operand_number)); + continue; + } + if (llvm::isa(owner)) { next_values_to_visit.push_back(use.getOwner()->getResult(0)); continue; @@ -291,15 +300,16 @@ std::optional GetXlaShardingFromArg( return std::nullopt; } -// Extracts sharding configurations for all inputs by parsing XlaSharding/ -// TPUPartitionedInput op connected to the operands/arguments. If argument to -// the `cluster_func` directly feeds into another function call op, then -// recursively walk the function definition to find the connected XlaSharding -// op. +// Tries to extract sharding configurations for all inputs by parsing +// XlaSharding/ TPUPartitionedInput op connected to the operands/arguments. If +// argument to the `cluster_func` directly feeds into another function call op, +// then recursively walk the function definition to find the connected +// XlaSharding op. void IdentifyXlaShardingForComputationInputs( - const llvm::SmallVector& logical_device_vec, bool use_spmd, + const llvm::SmallVector& logical_device_vec, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, OpShardingVector& sharding_for_args) { + func::FuncOp func, Builder* builder, + OptionalOpShardingVector& sharding_for_args) { // Look up function definition from module. Block& function_block = func.front(); @@ -310,8 +320,6 @@ void IdentifyXlaShardingForComputationInputs( // 1) a TPUPartitionedInput Op if the input has a non-resource type; // 2) a ReadVariableOp else. // - // Replicate sharding is used if `use_spmd` is set. 
- // // Iterate through input arguments to the entry block of // tf_device.ClusterFunc. For input ops, look for XlaSharding ops. // XlaSharding ops can: @@ -340,17 +348,7 @@ void IdentifyXlaShardingForComputationInputs( } } - if (use_spmd && !IsMaximalVariable(operand)) { - // If XLA SPMD is enabled, host variables or non-variable per-replica - // inputs should take on replicate sharding, so that every device gets the - // whole tensor(s) (and can slice them up later). Exclude device - // variables, which always should take maximal sharding. - sharding_for_args.push_back(kReplicateSharding); - continue; - } - - // Otherwise, default to maximal sharding core 0. - sharding_for_args.push_back(logical_device_vec[0]); + sharding_for_args.push_back(std::nullopt); } } @@ -374,31 +372,32 @@ mlir::Operation* GetXlaShardingFromResult(Value value) { return nullptr; } -// Looks up arg->retval aliases for every argument, and builds a reverse map. -void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl& aliases) { - aliases.resize(func.getNumResults(), -1); - for (int i = 0; i < func.getNumArguments(); i++) { - if (auto v = func.getArgAttrOfType(i, kAliasingAttr)) { - int retval_index = v.getInt(); - if (retval_index >= 0 && retval_index < aliases.size()) { - aliases[retval_index] = i; +absl::Status DetermineShardingFromAlias( + func::FuncOp func, OptionalOpShardingVector& input_shardings, + OptionalOpShardingVector& output_shardings) { + for (int arg_idx = 0; arg_idx < func.getNumArguments(); ++arg_idx) { + if (auto v = + func.getArgAttrOfType(arg_idx, kAliasingAttr)) { + if (int retval_idx = v.getInt(); + retval_idx >= 0 && retval_idx < func.getNumResults()) { + auto& input_sharding = input_shardings[arg_idx]; + auto& output_sharding = output_shardings[retval_idx]; + + if (input_sharding.has_value() && output_sharding.has_value() && + input_sharding.value() != output_sharding.value()) { + return absl::InvalidArgumentError(absl::StrCat( + "arg#", arg_idx, " is aliased to 
retval#", retval_idx, + " but their sharding configurations don't match.")); + } else if (input_sharding.has_value() && !output_sharding.has_value()) { + output_sharding = input_sharding; + } else if (!input_sharding.has_value() && output_sharding.has_value()) { + input_sharding = output_sharding; + } } } } -} -// Returns XLA sharding from argument connected via tf.aliasing_output. -std::optional GetXlaShardingFromAlias( - Value value, llvm::SmallVectorImpl& aliases, - const OpShardingVector& sharding_for_args) { - int retval_index = value.cast().getResultNumber(); - if (retval_index >= 0 && retval_index < aliases.size()) { - int arg_index = aliases[retval_index]; - if (arg_index >= 0 && arg_index < sharding_for_args.size()) { - return GetShardingStringFromVariant(sharding_for_args[arg_index]); - } - } - return std::nullopt; + return absl::OkStatus(); } // Returns XLA sharding from XlaSharding op connected to a result value. @@ -471,26 +470,20 @@ std::optional GetXlaShardingFromRetval( return std::nullopt; } -// Extracts sharding configurations for all outputs by parsing XlaSharding/ -// TPUPartitionedOutput op connected to the retvals/results. +// Tries to extract sharding configurations for all outputs by parsing +// XlaSharding/ TPUPartitionedOutput op connected to the retvals/results. 
void IdentifyXlaShardingForComputationOutputs( - const llvm::SmallVector& logical_device_vec, bool use_spmd, + const llvm::SmallVector& logical_device_vec, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, func::FuncOp func, Builder* builder, - const OpShardingVector& sharding_for_args, - OpShardingVector& sharding_for_rets) { + OptionalOpShardingVector& sharding_for_rets) { Block& function_block = func.front(); Operation* terminator = function_block.getTerminator(); sharding_for_rets.reserve(terminator->getNumOperands()); - llvm::SmallVector aliases; // maps return value index to arg index - ExtractAliases(func, aliases); - // Iterate through results of `cluster_func`. For output ops, look for // TPUPartitionedOutput ops. // - // Replicate sharding is used if `use_spmd` is set. - // // Iterate through operands of the terminator. If the preceding op is // XlaShardingOp, then the provided sharding configuration is added to the // tf_device.ClusterFunc as an attribute and the function as a result @@ -505,12 +498,6 @@ void IdentifyXlaShardingForComputationOutputs( continue; } - if (auto from_alias = - GetXlaShardingFromAlias(result, aliases, sharding_for_args)) { - sharding_for_rets.push_back(from_alias.value()); - continue; - } - if (infer_from_computation) { if (auto retval_sharding = GetXlaShardingFromRetval(retval.get(), logical_device_vec)) { @@ -519,18 +506,76 @@ void IdentifyXlaShardingForComputationOutputs( } } - if (use_spmd) { - // If XLA SPMD is enabled, we default to replicate sharding. This way, - // all devices get the whole tensor(s), but if there's an XlaSharding op - // deeper in the function, they can use dynamic-slice to slice off their - // part of the computation. 
- sharding_for_rets.push_back(kReplicateSharding); - continue; + sharding_for_rets.push_back(std::nullopt); + } +} + +void SetReplicatedOrMaximalShardingIfNoShardingFound( + const llvm::SmallVector& logical_device_vec, bool use_spmd, + OptionalOpShardingVector& shardings) { + for (auto& sharding : shardings) { + if (sharding == std::nullopt) { + // If we haven't found sharding, default to either replicated or maximal + // sharding depending on whether XLA SPMD is enabled. + if (use_spmd) { + // If XLA SPMD is enabled, host variables or non-variable per-replica + // inputs, and outputs should take on replicate sharding, so that every + // device gets the whole tensor(s) (and can slice them up later eg. + // using dynamic-slice). + sharding = kReplicateSharding; + } else { + // Otherwise, default to maximal sharding core 0. + sharding = logical_device_vec[0]; + } + } + } +} + +// Moves shardings from `optional_shardings` to `shardings`. +absl::Status MoveSharding(OptionalOpShardingVector& optional_shardings, + OpShardingVector& shardings) { + shardings.clear(); + for (auto& sharding : optional_shardings) { + if (!sharding) { + return absl::InternalError( + "Couldn't find/assign sharding for an input/output. All shardings " + "should have been identified by this point."); } - // Otherwise, default to maximal sharding core 0. - sharding_for_rets.push_back(logical_device_vec[0]); + shardings.push_back(std::move(sharding.value())); } + + return absl::OkStatus(); +} + +// Determines XlaSharding for inputs and outputs. If there are aliased +// inputs/outputs for which no sharding was found directly, the corresponding +// output/input sharding is used (if it exists). If we still don't find sharding +// for some inputs/outputs, we default to replicated or maximal sharding +// depending on `use_spmd`. 
+absl::Status IdentifyXlaShardingForInputsAndOutputs( + const llvm::SmallVector& logical_device_vec, bool use_spmd, + bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, + func::FuncOp func, Builder* builder, OpShardingVector& input_sharding, + OpShardingVector& output_sharding) { + OptionalOpShardingVector optional_input_sharding; + OptionalOpShardingVector optional_output_sharding; + IdentifyXlaShardingForComputationInputs( + logical_device_vec, infer_from_computation, cluster_func, func, builder, + optional_input_sharding); + IdentifyXlaShardingForComputationOutputs( + logical_device_vec, infer_from_computation, cluster_func, func, builder, + optional_output_sharding); + TF_RETURN_IF_ERROR(DetermineShardingFromAlias(func, optional_input_sharding, + optional_output_sharding)); + SetReplicatedOrMaximalShardingIfNoShardingFound(logical_device_vec, use_spmd, + optional_input_sharding); + SetReplicatedOrMaximalShardingIfNoShardingFound(logical_device_vec, use_spmd, + optional_output_sharding); + TF_RETURN_IF_ERROR(MoveSharding(optional_input_sharding, input_sharding)); + TF_RETURN_IF_ERROR(MoveSharding(optional_output_sharding, output_sharding)); + + return absl::OkStatus(); } // Extracts input/output sharding configuration of `cluster_func` by parsing @@ -561,15 +606,15 @@ LogicalResult IdentifyXlaShardingForTPUComputation( } OpShardingVector sharding_for_args; - IdentifyXlaShardingForComputationInputs(logical_device_vec, use_spmd, - /*infer_from_computation=*/true, - cluster_func, func, builder, - sharding_for_args); - OpShardingVector sharding_for_rets; - IdentifyXlaShardingForComputationOutputs( - logical_device_vec, use_spmd, /*infer_from_computation=*/true, - cluster_func, func, builder, sharding_for_args, sharding_for_rets); + if (auto status = IdentifyXlaShardingForInputsAndOutputs( + logical_device_vec, use_spmd, + /*infer_from_computation=*/true, cluster_func, func, builder, + sharding_for_args, sharding_for_rets); + !status.ok()) { + 
LOG(ERROR) << status; + return failure(); + }; auto has_maximal_sharding = [](const OpShardingVariant& sharding_or_op) -> bool { @@ -592,14 +637,14 @@ LogicalResult IdentifyXlaShardingForTPUComputation( sharding_for_rets.clear(); cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false)); - IdentifyXlaShardingForComputationInputs( - logical_device_vec, - /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, - func, builder, sharding_for_args); - IdentifyXlaShardingForComputationOutputs( - logical_device_vec, - /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, - func, builder, sharding_for_args, sharding_for_rets); + if (auto status = IdentifyXlaShardingForInputsAndOutputs( + logical_device_vec, /*use_spmd=*/false, + /*infer_from_computation=*/false, cluster_func, func, builder, + sharding_for_args, sharding_for_rets); + !status.ok()) { + LOG(ERROR) << status; + return failure(); + } } // Update sharding on function arguments and returns. @@ -645,7 +690,7 @@ void TPUShardingIdentificationPass::runOnOperation() { if (result.wasInterrupted()) return signalPassFailure(); } -} // anonymous namespace +} // namespace std::unique_ptr> CreateTPUShardingIdentificationPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc index 615911f0cbd741..9f36e838206804 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc @@ -116,7 +116,8 @@ LogicalResult FillOpToParallelIdsMap( if (id_pairs.empty()) continue; TF::ParallelIdsMap& ids_map = op_to_parallel_ids_map[island]; - for (auto [group_id, branch_id] : id_pairs) ids_map[group_id] = branch_id; + for (const auto& [group_id, branch_id] : id_pairs) + ids_map[group_id] = branch_id; } return success(); } diff --git 
a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 4160c31515cf62..95c67e6084d90a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -182,6 +182,17 @@ void LoadImporterDialects(mlir::MLIRContext& context) { context.getOrLoadDialect(name); } +absl::StatusOr GetDenseTensorNameFromTensorInfo( + const TensorInfo& tensor_info) { + // TODO(b/184675681): Support other encoding cases. + // + // TODO(b/184679394): Add unit test for this check. + TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) + << "Only dense tensor is supported, but got encoding case " + << tensor_info.encoding_case(); + return tensor_info.name(); +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -3945,7 +3956,11 @@ SavedModelSignatureDefImporterLite::ConvertGraph( specs.graph_func_name = name; specs.prune_unused_nodes = true; TF_ASSIGN_OR_RETURN(specs.inputs, ParseInputArrays(inputs)); - for (auto& output : outputs) specs.outputs.push_back(output.second.name()); + for (auto& output : outputs) { + TF_ASSIGN_OR_RETURN(std::string name, + GetDenseTensorNameFromTensorInfo(output.second)); + specs.outputs.push_back(std::move(name)); + } specs.control_outputs = control_outputs; specs.enable_shape_inference = false; specs.unconditionally_use_set_output_shapes = @@ -4031,12 +4046,8 @@ SavedModelSignatureDefImporterLite::ParseInputArrays( for (const auto& iter : inputs) { const auto& tensor_info = iter.second; - // TODO(b/184675681): Support other encoding cases. - // - // TODO(b/184679394): Add unit test for this check. 
- TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) - << "Only dense tensor is supported, but got encoding case " - << tensor_info.encoding_case(); + TF_ASSIGN_OR_RETURN(std::string name, + GetDenseTensorNameFromTensorInfo(tensor_info)); VLOG(1) << "Importing Signature Input: input_name = " << iter.first << ", tensor_info = " << tensor_info.DebugString(); @@ -4052,7 +4063,7 @@ SavedModelSignatureDefImporterLite::ParseInputArrays( array_info.shape.set_unknown_rank(true); } - results.insert(std::pair(tensor_info.name(), + results.insert(std::pair(std::move(name), std::move(array_info))); } return results; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 693d1f37766d81..4efec99e6ecd4d 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -100,6 +100,7 @@ cc_library( hdrs = ["compile_tf_graph.h"], deps = [ "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -108,6 +109,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/mlir/tensorflow/transforms:set_tpu_infeed_layout", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", @@ -142,6 +144,9 @@ tf_cc_test( name = "compile_tf_graph_test", testonly = 1, srcs = ["compile_tf_graph_test.cc"], + data = [ + "testdata/prepare_to_library.mlir", + ], linkstatic = 1, deps = [ ":compile_tf_graph", @@ -161,6 +166,7 @@ tf_cc_test( "//tensorflow/core/lib/monitoring:cell_reader", 
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/core/tpu/kernels/xla:host_compute_ops", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@llvm-project//mlir:FuncDialect", @@ -171,7 +177,6 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", - "@local_xla//xla/stream_executor", "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 003732ffb22f5a..0355204506068c 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -62,6 +64,7 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" +#include "tensorflow/core/util/debug_data_dumper.h" #include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -170,11 +173,21 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, applyTensorflowAndCLOptions(manager); manager.addPass(mlir::TF::CreatePrepareTpuComputationForTfExportPass()); manager.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + manager.addPass(mlir::TF::CreateTFShapeInferencePass()); manager.addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); manager.addPass(mlir::CreateBreakUpIslandsPass()); mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + + if (VLOG_IS_ON(2)) { + llvm::StringRef module_name = llvm::StringRef(); + constexpr const char* kDebugGroupBridgePhase2 = + "v1_prepare_and_export_to_library"; + internal::EnablePassIRPrinting(manager, kDebugGroupBridgePhase2, + module_name); + } + auto prepare_status = manager.run(module); auto diag_handler_status = diag_handler.ConsumeStatus(); // There are cases where the scoped diagnostic handler catches a failure that diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index e3b2339eedbba5..9bea916088c133 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include #include #include @@ -213,9 +214,9 @@ TEST_F(CompileTFGraphTest, RecordsStreamzForFunctionToHlo) { EXPECT_EQ(compilation_status.Delta("kOldBridgeNoMlirSuccess"), 1); } -TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { +TEST_F(CompileTFGraphTest, SuccessfullyCompilesWithManualSharding) { // MLIR module from failing test. - constexpr char kUnsupportedManualSharding[] = R"( + constexpr char kSupportedManualSharding[] = R"( module @module___inference_tpu_function_41 attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1617 : i32}} { func.func @main(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) { %0 = tf_executor.graph { @@ -223,7 +224,7 @@ TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { %outputs_0, %control_1 = tf_executor.island wraps "tf.XlaSharding"(%outputs) {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01", sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<2x2xf32> %outputs_2, %control_3 = tf_executor.island wraps "tf.XlaSpmdFullToShardShape"(%outputs_0) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<1x2xf32> %control_4 = tf_executor.island wraps "tf._XlaHostComputeMlir"(%outputs_2) {host_mlir_module = "", manual_sharding = true, recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"} : (tensor<1x2xf32>) -> () - %outputs_5, %control_6 = tf_executor.island(%control_4) wraps "tf._XlaHostComputeMlir"() {host_mlir_module = "", manual_sharding = true, recv_key = "host_compute_channel_1_retvals", send_key = "host_compute_channel_1_args"} : () -> tensor<1x2xf32> + %outputs_5, %control_6 = tf_executor.island(%control_4) wraps "tf._XlaHostComputeMlir"() {host_mlir_module = "module {\0A func.func @host_func() -> tensor<1x2xf32> {\0A %0 
= \22tf.Const\22() {value = dense<0.1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> \0A return %0 : tensor<1x2xf32>}}", manual_sharding = true, recv_key = "host_compute_channel_1_retvals", send_key = "host_compute_channel_1_args"} : () -> tensor<1x2xf32> %outputs_7, %control_8 = tf_executor.island wraps "tf.XlaSpmdShardToFullShape"(%outputs_5) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xf32>) -> tensor<2x2xf32> %outputs_9, %control_10 = tf_executor.island wraps "tf.XlaSharding"(%outputs_7) {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01", sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xf32>) -> tensor<2x2xf32> tf_executor.fetch %outputs_9 : tensor<2x2xf32> @@ -232,13 +233,11 @@ TEST_F(CompileTFGraphTest, CatchesErrorMissedByPassManagerRun) { } } )"; - auto mlir_to_hlo_args = CreateTestMlirToHloArgs(kUnsupportedManualSharding); + auto mlir_to_hlo_args = CreateTestMlirToHloArgs(kSupportedManualSharding); auto result = CompileWithComputation(mlir_to_hlo_args); - ASSERT_THAT(result.ok(), false); - EXPECT_THAT(result.status().message(), - testing::ContainsRegex("op manual_sharding")); + EXPECT_TRUE(result.ok()); } TEST_F(CompileTFGraphTest, DoesNotInlineStatelessRandomOps) { @@ -261,6 +260,24 @@ TEST_F(CompileTFGraphTest, DoesNotInlineStatelessRandomOps) { ComputationProtoContains("tf.StatelessRandomNormal")); } +TEST_F(CompileTFGraphTest, TestRunsShapeInference) { + static constexpr char kShapeInferenceModule[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + %0 = "tf.Const"() <{value = dense<-1> : tensor<3360x8xi32>}> : () -> tensor<3360x8xi32> + %cst_33 = "tf.Const"() <{value = dense<[1120, -1]> : tensor<2xi32>}> : () -> tensor<2xi32> + %cst_34 = "tf.Const"() <{value = dense<[3, 1120, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_63 = 
"tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %1965:4 = "tf._XlaHostComputeMlir"(%0, %cst_34, %cst_63, %cst_33) <{host_mlir_module = "#loc1 = loc(\22Reshape:\22)\0A#loc2 = loc(\22Reshape_4\22)\0A#loc3 = loc(\22Reshape\22)\0A#loc9 = loc(fused[#loc1, #loc2, #loc3])\0Amodule {\0A func.func @host_func(%arg0: tensor<3360x?xi32> loc(fused[#loc1, #loc2, #loc3]), %arg1: tensor<3xi32> loc(fused[#loc1, #loc2, #loc3]), %arg2: tensor loc(fused[#loc1, #loc2, #loc3]), %arg3: tensor<2xi32> loc(fused[#loc1, #loc2, #loc3])) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) {\0A %0 = \22tf.Reshape\22(%arg0, %arg1) {_xla_outside_compilation = \220\22} : (tensor<3360x?xi32>, tensor<3xi32>) -> tensor<3x1120x?xi32> loc(#loc9)\0A %1:3 = \22tf.Split\22(%arg2, %0) {_xla_outside_compilation = \220\22} : (tensor, tensor<3x1120x?xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1x1120x?xi32>) loc(#loc10)\0A %2 = \22tf.Reshape\22(%1#0, %arg3) {_xla_outside_compilation = \220\22} : (tensor<1x1120x?xi32>, tensor<2xi32>) -> tensor<1120x?xi32> loc(#loc11)\0A %3 = \22tf.Shape\22(%2) {_xla_outside_compilation = \220\22} : (tensor<1120x?xi32>) -> tensor<2xi32> loc(#loc12)\0A return %1#1, %1#2, %2, %3 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> loc(#loc9)\0A } loc(#loc9)\0A} loc(#loc)\0A#loc = loc(unknown)\0A#loc4 = loc(\22Split:\22)\0A#loc5 = loc(\22split\22)\0A#loc6 = loc(\22Reshape_5\22)\0A#loc7 = loc(\22Shape:\22)\0A#loc8 = loc(\22Shape_4\22)\0A#loc10 = loc(fused[#loc4, #loc5])\0A#loc11 = loc(fused[#loc1, #loc6])\0A#loc12 = loc(fused[#loc7, #loc8])\0A", recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"}> : (tensor<3360x8xi32>, tensor<3xi32>, tensor, tensor<2xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) + return + } + } + )"; + + auto compilation_result = + 
CompileWithComputation(CreateTestMlirToHloArgs(kShapeInferenceModule)); + EXPECT_TRUE(compilation_result.ok()); +} } // namespace } // namespace v1 } // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir b/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir new file mode 100644 index 00000000000000..42e145effa742f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/testdata/prepare_to_library.mlir @@ -0,0 +1,10 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + %0 = "tf.Const"() <{value = dense<-1> : tensor<3360x8xi32>}> : () -> tensor<3360x8xi32> + %cst_33 = "tf.Const"() <{value = dense<[1120, -1]> : tensor<2xi32>}> : () -> tensor<2xi32> + %cst_34 = "tf.Const"() <{value = dense<[3, 1120, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_63 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %1965:4 = "tf._XlaHostComputeMlir"(%0, %cst_34, %cst_63, %cst_33) <{host_mlir_module = "#loc1 = loc(\22Reshape:\22)\0A#loc2 = loc(\22Reshape_4\22)\0A#loc3 = loc(\22Reshape\22)\0A#loc9 = loc(fused[#loc1, #loc2, #loc3])\0Amodule {\0A func.func @host_func(%arg0: tensor<3360x?xi32> loc(fused[#loc1, #loc2, #loc3]), %arg1: tensor<3xi32> loc(fused[#loc1, #loc2, #loc3]), %arg2: tensor loc(fused[#loc1, #loc2, #loc3]), %arg3: tensor<2xi32> loc(fused[#loc1, #loc2, #loc3])) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) {\0A %0 = \22tf.Reshape\22(%arg0, %arg1) {_xla_outside_compilation = \220\22} : (tensor<3360x?xi32>, tensor<3xi32>) -> tensor<3x1120x?xi32> loc(#loc9)\0A %1:3 = \22tf.Split\22(%arg2, %0) {_xla_outside_compilation = \220\22} : (tensor, tensor<3x1120x?xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1x1120x?xi32>) loc(#loc10)\0A %2 = \22tf.Reshape\22(%1#0, %arg3) {_xla_outside_compilation = \220\22} : (tensor<1x1120x?xi32>, tensor<2xi32>) -> 
tensor<1120x?xi32> loc(#loc11)\0A %3 = \22tf.Shape\22(%2) {_xla_outside_compilation = \220\22} : (tensor<1120x?xi32>) -> tensor<2xi32> loc(#loc12)\0A return %1#1, %1#2, %2, %3 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> loc(#loc9)\0A } loc(#loc9)\0A} loc(#loc)\0A#loc = loc(unknown)\0A#loc4 = loc(\22Split:\22)\0A#loc5 = loc(\22split\22)\0A#loc6 = loc(\22Reshape_5\22)\0A#loc7 = loc(\22Shape:\22)\0A#loc8 = loc(\22Shape_4\22)\0A#loc10 = loc(fused[#loc4, #loc5])\0A#loc11 = loc(fused[#loc1, #loc6])\0A#loc12 = loc(fused[#loc7, #loc8])\0A", recv_key = "host_compute_channel_0_retvals", send_key = "host_compute_channel_0_args"}> : (tensor<3360x8xi32>, tensor<3xi32>, tensor, tensor<2xi32>) -> (tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32>) + return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc index 9d0b884ebbe85d..aa9a73215b7284 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.cc @@ -91,8 +91,12 @@ void AddTfDialectToExecutorPasses(OpPassManager &pm) { pm.addPass(mlir::createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { + bool composite_tpuexecute_side_effects = + tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_composite_tpuexecute_side_effects; pm.addPass( - mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()); + mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass( + composite_tpuexecute_side_effects)); } pm.addPass(mlir::TF::CreateVerifySuitableForExportPass()); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 6749f012da37fb..76369ade79193f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ 
b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -35,6 +35,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_tf_graph", + "//tensorflow/compiler/mlir/tf2xla/internal:compilation_timer", "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_mlir", "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_to_hlo", "//tensorflow/compiler/tf2xla:layout_util", @@ -52,6 +53,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:error_logging", @@ -76,12 +78,15 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:statusor", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 4fa9b9bb98da3f..9729f755c611ce 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -160,19 +160,18 @@ void CreateTPUClusteringPipelineV2(OpPassManager &pm) { } tensorflow::Status RunFunctionTf2xlaClusteringBridge( - ModuleOp module, DeviceType device_type, bool is_in_fallback_enabled_mode, - llvm::StringRef module_name) { - bool is_replicated = device_type == 
DeviceType::XLA_TPU_JIT; + ModuleOp module, bool is_supported_by_replicated_brige, + bool is_in_fallback_enabled_mode, llvm::StringRef module_name) { std::string device_type_filter = - device_type == DeviceType::XLA_TPU_JIT ? "tpu" : "cpu/gpu"; + is_supported_by_replicated_brige ? "tpu" : "cpu/gpu"; VLOG(2) - << (is_replicated ? "Replicated" : "NonReplicated") + << (is_supported_by_replicated_brige ? "Replicated" : "NonReplicated") << " Bridge called stack trace is " << "(NOTE: this is not an error; rather the stack trace for debugging) : " << tensorflow::CurrentStackTrace(); Status clustering_status = - is_replicated + is_supported_by_replicated_brige ? RunTFXLABridge( module, [module_name](OpPassManager &pm) { @@ -187,12 +186,12 @@ tensorflow::Status RunFunctionTf2xlaClusteringBridge( }, module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nonreplicated"); - // TODO(b/317798386): add is_replicated as a filter. + // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. TF_RETURN_IF_ERROR(RecordIfErrorStatus( /*error_prefix=*/"clustering_v2", is_in_fallback_enabled_mode, device_type_filter, clustering_status)); - // TODO(b/317798386): add is_replicated as a filter. + // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( device_type_filter, /*bridge_version=*/"v2", /*fallback_enabled=*/is_in_fallback_enabled_mode, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h index e1298ac53560d3..8963fe7b126b9e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h @@ -39,17 +39,19 @@ namespace v2 { // Inputs: // module - The MLIR Module that will be clustered. Expected to be in TF // Executor Dialect or TF Functional Dialect. Will convert to TF Functional. -// . device_type - The device type to cluster for. 
-// is_in_fallback_enabled_mode - Whether this was called with fallback to the -// non-MLIR Bridge. This is just for logging purposes and doesn't affect -// logic. -// module_name - What the input module name is for debugging help. +// is_supported_by_replicated_brige - If the graph targets the replicated +// bridge. Set it to true for replicated/partitioned graphs. e.g. replicated +// and single-core TPU graphs. Set this to false if the graph is not +// replicated, e.g. CPU/GPU graphs. is_in_fallback_enabled_mode - Whether this +// was called with fallback to the non-MLIR Bridge. This is just for logging +// purposes and doesn't affect logic. module_name - What the input module name +// is for debugging help. // // Output: Modifies the input module in place with clustered operations. // status - Whether the transformation to cluster the input MLIR module was // successful. tensorflow::Status RunFunctionTf2xlaClusteringBridge( - mlir::ModuleOp module, DeviceType device_type, + mlir::ModuleOp module, bool is_supported_by_replicated_brige, bool is_in_fallback_enabled_mode, llvm::StringRef module_name = llvm::StringRef()); } // namespace v2 diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index d00d8b43d9e790..c4a96702533c49 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -82,14 +82,14 @@ class FunctionClusterTensorflowDialectTest : public ::testing::Test { OwningOpRef mlir_module_; }; -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfTPU) { +TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + 
TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -98,14 +98,15 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfTPU) { compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); } -TEST_F(FunctionClusterTensorflowDialectTest, RunsOutsideCompilationTPU) { +TEST_F(FunctionClusterTensorflowDialectTest, + RunsOutsideCompilationReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("outside_compilation.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -121,31 +122,14 @@ TEST_F(FunctionClusterTensorflowDialectTest, RunsOutsideCompilationTPU) { compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); } -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFCPU) { +TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { CellReader compilation_status(kCompilationStreamz); TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_CPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); - - FuncOp main = mlir_module_->lookupSymbol("main"); - ASSERT_TRUE(main); - - EXPECT_EQ( - compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"), - 1); -} - -TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFGPU) { - CellReader compilation_status(kCompilationStreamz); - - TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, 
DeviceType::XLA_GPU_JIT, - /*is_in_fallback_enabled_mode=*/false)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ false, + /*is_in_fallback_enabled_mode=*/false)); FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); @@ -160,9 +144,9 @@ TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) { TF_ASSERT_OK(CreateMlirModule("empty_func.mlir")); - TF_EXPECT_OK( - RunFunctionTf2xlaClusteringBridge(*mlir_module_, DeviceType::XLA_TPU_JIT, - /*is_in_fallback_enabled_mode=*/true)); + TF_EXPECT_OK(RunFunctionTf2xlaClusteringBridge( + *mlir_module_, /*is_supported_by_replicated_brige*/ true, + /*is_in_fallback_enabled_mode=*/true)); EXPECT_EQ( compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc index 8bfb3308ffead2..d297e45b70e0bb 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc @@ -25,9 +25,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/variant.h" +#include "llvm/ADT/ScopeExit.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h" #include "tensorflow/compiler/tf2xla/layout_util.h" @@ -38,7 +40,9 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/util/debug_data_dumper.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/error_logging.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -53,9 +57,17 @@ using tpu::FunctionToHloArgs; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; +auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( + {"/tensorflow/core/tf2xla/api/v2/phase2_compilation_time", + "The wall-clock time spent on executing graphs in milliseconds.", + "configuration"}, + // Power of 1.5 with bucket count 45 (> 23 hours) + {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); + // Name of component for error logging. This name is fixed and required to // enable logging. constexpr char kBridgeComponent[] = "TFXLABridge"; +constexpr char kFullBridge[] = "full_bridge"; namespace { @@ -73,6 +85,7 @@ void DumpComputationInput( if (!VLOG_IS_ON(2)) { return; } + switch (computation.index()) { case 0: VLOG(2) << "LegalizeMlirToHlo with MLIR computation input: " @@ -95,25 +108,28 @@ void DumpComputationInput( Status DumpHloCompilationResult(std::string_view name, XlaCompilationResult* compilation_result) { - if (VLOG_IS_ON(2)) { - TF_ASSIGN_OR_RETURN( - auto hlo_module_config, - xla::HloModule::CreateModuleConfigFromProto( - compilation_result->computation->proto(), xla::DebugOptions())); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - xla::HloModule::CreateFromProto( - compilation_result->computation->proto(), hlo_module_config)); - - std::string all_computations; - for (auto computation : hlo_module->computations()) { - all_computations += computation->ToString() + "\n\n"; - } - - tensorflow::DumpRawStringToFile(name, all_computations); + if (!VLOG_IS_ON(2) && + !DEBUG_DATA_DUMPER()->ShouldDump(std::string(name), 
kDebugGroupMain)) { + return OkStatus(); } + TF_ASSIGN_OR_RETURN( + auto hlo_module_config, + xla::HloModule::CreateModuleConfigFromProto( + compilation_result->computation->proto(), xla::DebugOptions())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(compilation_result->computation->proto(), + hlo_module_config)); + + std::string all_computations; + for (auto computation : hlo_module->computations()) { + all_computations += computation->ToString() + "\n\n"; + } + + tensorflow::DumpRawStringToFile(name, all_computations); + return OkStatus(); } @@ -129,6 +145,12 @@ tsl::StatusOr LegalizeMlirToHlo( std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client) { + CompilationTimer timer; + auto record_time = llvm::make_scope_exit([&timer] { + phase2_bridge_compilation_time->GetCell(kFullBridge) + ->Add(timer.ElapsedCyclesInMilliseconds()); + }); + auto compilation_result = std::make_unique(); DumpComputationInput(computation); @@ -140,7 +162,8 @@ tsl::StatusOr LegalizeMlirToHlo( arg_shapes, arg_core_mapping, per_core_arg_shapes, client, compilation_result.get())); - DumpHloCompilationResult("legalize_tf_fallback", compilation_result.get()) + DumpHloCompilationResult("legalize_tf_fallback.hlo", + compilation_result.get()) .IgnoreError(); return *compilation_result; } @@ -155,15 +178,15 @@ tsl::StatusOr LegalizeMlirToHlo( VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using " "Combined MLIR and XlaBuilder Bridge."; - DumpHloCompilationResult("legalize_tf_combined_bridge", + DumpHloCompilationResult("legalize_tf_combined_bridge.hlo", compilation_result.get()) .IgnoreError(); return *compilation_result; } - VLOG(1) - << "Failed to compile MLIR computation to XLA HLO using " - "Combined MLIR and XlaBuilder Bridge. Falling back to Graph Bridge."; + VLOG(1) << "Failed to compile MLIR computation to XLA HLO using Combined " + "MLIR and XlaBuilder Bridge. 
Falling back to MLIR tf2xla Bridge. " + << combined_bridge_status.status(); tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V2_COMBINED_BRIDGE", combined_bridge_status.status().ToString()) .IgnoreError(); @@ -176,21 +199,22 @@ tsl::StatusOr LegalizeMlirToHlo( if (mlir_bridge_status.ok()) { VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " - "tf2xla bridge"; + "tf2xla Bridge"; IncrementTfMlirBridgeSecondPhaseCounter( MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeSuccess); - DumpHloCompilationResult("legalize_tf_mlir_bridge", + DumpHloCompilationResult("legalize_tf_mlir_bridge.hlo", compilation_result.get()) .IgnoreError(); return *compilation_result; } else if (mlir_bridge_status.status() == CompileToHloGraphAnalysisFailedError()) { VLOG(1) << "Filtered out MLIR computation to XLA HLO using MLIR tf2xla " - "bridge. Falling back to Combined Bridge."; + "Bridge. Could not generate HLO."; } else { - VLOG(1) << "Failed to compile MLIR computation to XLA HLO using " - "MLIR Bridge. Falling back to Combined Bridge."; + VLOG(1) << "Failed to compile MLIR computation to XLA HLO using MLIR " + "tf2xla Bridge. Could not generate HLO. " + << mlir_bridge_status.status(); tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V2_PHASE2_MLIR_BRIDGE", mlir_bridge_status.status().ToString()) .IgnoreError(); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc index ab7aa70b9bcc74..99c213916f6f3f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc @@ -17,19 +17,26 @@ limitations under the License. 
#include #include +#include #include #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/lib/monitoring/test_utils.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/util/debug_data_dumper.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/monitoring/test_utils.h" #include "tsl/platform/statusor.h" @@ -38,14 +45,15 @@ namespace tf2xla { namespace v2 { using ::tensorflow::monitoring::testing::CellReader; +using ::testing::TestWithParam; using tpu::FunctionToHloArgs; using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; using tpu::TPUCompileMetadataProto; -using ::tsl::monitoring::testing::Histogram; static constexpr char kCompilationTimeStreamzName[] = "/tensorflow/core/tf2xla/api/v2/phase2_compilation_time"; +static constexpr char kFullBridge[] = "full_bridge"; static constexpr char kCompilationStatusStreamzName[] = "/tensorflow/core/tf2xla/api/v2/phase2_compilation_status"; static const char kMlirWithFallbackModeSuccess[] = @@ -106,6 +114,7 @@ tsl::StatusOr CompileMlirModule( std::vector arg_shapes; TPUCompileMetadataProto metadata_proto; + metadata_proto.add_retvals(); bool use_tuple_args = true; std::vector arg_core_mapping; std::vector> per_core_arg_shapes; @@ -131,6 +140,89 @@ TEST(LegalizeTFTest, RecordsStreamzForSuccessfulLegalizeWithMlirBridge) { EXPECT_EQ(compilation_status.Delta(kMlirWithFallbackModeFailure), 0); } +TEST(LegalizeTFTest, MatMul) 
{ + static constexpr char kMatMulModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> (tensor<5x11xf32>) { + %arg0 = "tf.Const"() {value = dense<-3.0> : tensor<5x7xf32>} : () -> tensor<5x7xf32> + %arg1 = "tf.Const"() {value = dense<-3.0> : tensor<11x7xf32>} : () -> tensor<11x7xf32> + + %1 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> + + func.return %1 : tensor<5x11xf32> + } + })"; + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + kMatMulModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)); +} + +struct MatMulTestCase { + std::string mat_mul_method; +}; + +using BatchMatMulTest = TestWithParam; + +TEST_P(BatchMatMulTest, BatchMatMul) { + const MatMulTestCase& test_case = GetParam(); + static constexpr char kMatMulModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> (tensor<1x4x4xf32>) { + %%arg0 = "tf.Const"() {value = dense<-3.0> : tensor<1x4x2xf32>} : () -> tensor<1x4x2xf32> + %%arg1 = "tf.Const"() {value = dense<-3.0> : tensor<1x2x4xf32>} : () -> tensor<1x2x4xf32> + + %%1 = "tf.%s"(%%arg0, %%arg1) {T = f32, adj_x = false, adj_y = false, grad_x = false, grad_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<1x2x4xf32>) -> tensor<1x4x4xf32> + + func.return %%1 : tensor<1x4x4xf32> + } + })"; + std::string mat_mul_method = + absl::StrFormat(kMatMulModuleStr, test_case.mat_mul_method); + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + mat_mul_method.c_str(), + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)); +} + +INSTANTIATE_TEST_SUITE_P( + BatchMatMulTest, BatchMatMulTest, + ::testing::ValuesIn({ + {"BatchMatMul"}, + {"BatchMatMulV2"}, + {"BatchMatMulV3"}, + }), + [](const 
::testing::TestParamInfo& info) { + return info.param.mat_mul_method; + }); + +TEST(LegalizeTFTest, DumpsProducedHLO) { + Env* env = Env::Default(); + std::string test_dir = testing::TmpDir(); + setenv("TF_DUMP_GRAPH_PREFIX", test_dir.c_str(), /*overwrite=*/1); + setenv("TF_DUMP_GRAPH_NAME_FILTER", "*", 1); + DEBUG_DATA_DUMPER()->LoadEnvvars(); + + std::vector files; + TF_ASSERT_OK(env->GetChildren(test_dir, &files)); + int original_files_size = files.size(); + + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + kMlirModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)); + + // Due to the shared test of this infrastructure, we just need to make sure + // that the dumped file size is greater than what was originally inside + // the test directory. + TF_ASSERT_OK(env->GetChildren(test_dir, &files)); + EXPECT_THAT(files.size(), ::testing::Gt(original_files_size)); + setenv("TF_DUMP_GRAPH_PREFIX", test_dir.c_str(), /*overwrite=*/0); +} + TEST(LegalizeTFTest, RecordsStreamzForFailedLegalizeWithMlirBridge) { CellReader compilation_status(kCompilationStatusStreamzName); @@ -203,6 +295,20 @@ TEST(LegalizeTFTest, RecordsStreamzForNoMlirFallback) { EXPECT_FALSE(compile_result.ok()); } +TEST(LegalizeTFTest, RecordsCompilationTimeForSuccessfulCompilation) { + CellReader compilation_time( + kCompilationTimeStreamzName); + + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + kMlirModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED)); + + // Compilation time should have been updated. 
+ EXPECT_GT(compilation_time.Delta(kFullBridge).num(), 0); +} + } // namespace v2 } // namespace tf2xla } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc index 455a59d6607c49..9befa9b7714a27 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.cc @@ -90,8 +90,12 @@ void AddTfDialectToExecutorPasses(OpPassManager &pm) { pm.addPass(mlir::createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { + bool composite_tpuexecute_side_effects = + tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_composite_tpuexecute_side_effects; pm.addPass( - mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()); + mlir::tf_executor::CreateTFExecutorConvertControlToDataOutputsPass( + composite_tpuexecute_side_effects)); } pm.addPass(mlir::TF::CreateVerifySuitableForExportPass()); } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index f211357de28df9..1f96ecb120795d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -243,3 +243,47 @@ tf_cc_test( "@local_tsl//tsl/platform:status", ], ) + +cc_library( + name = "mlir_bridge_pass_util", + srcs = ["mlir_bridge_pass_util.cc"], + hdrs = ["mlir_bridge_pass_util.h"], + visibility = ["//tensorflow/compiler/tf2xla:__pkg__"], + deps = [ + "//tensorflow/compiler/tf2xla:tf2xla_defs", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:function_body", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status", + ], + alwayslink = 1, +) + +tf_cc_test( + name 
= "mlir_bridge_pass_util_test", + srcs = ["mlir_bridge_pass_util_test.cc"], + deps = [ + ":mlir_bridge_pass_util", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:tf2xla_defs", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/platform:enable_tf2_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/lib/core:status_test_util", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index ed04adc6a394ff..f00df1513215cc 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -143,6 +143,8 @@ void AddReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm, pm.addNestedPass(mlir::TFDevice::CreateClusterConstantSinkingPass()); pm.addPass(mlir::TF::CreateResourceDeviceInferencePass()); + pm.addNestedPass( + tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()); pm.addPass(mlir::TFDevice::CreateClusterOutliningPass()); pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass()); pm.addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index c9cc5a4d1df16c..756ec42f31268c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -28,7 +28,7 @@ TEST(ClusteringBridgePassesTest, AddsBridgePasses) { OpPassManager 
pass_manager; AddReplicatedBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 43); + EXPECT_EQ(pass_manager.size(), 44); } TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc index 3b11eaaf2287d6..4fd0c21d68331f 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc @@ -44,7 +44,6 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" -#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/error_logging.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -52,12 +51,6 @@ limitations under the License. namespace tensorflow { namespace tf2xla { namespace internal { -auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( - {"/tensorflow/core/tf2xla/api/v2/phase2_compilation_time", - "The wall-clock time spent on executing graphs in milliseconds.", - "configuration"}, - // Power of 1.5 with bucket count 45 (> 23 hours) - {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); // Name of component for error logging. This name is fixed and required to // enable logging. @@ -126,20 +119,11 @@ tsl::StatusOr LegalizeWithMlirBridge( // Enabling op fallback also enables whole graph fallback if op by op // fallback failed. 
- tsl::StatusOr mlir_bridge_status; - { - CompilationTimer timer; - const std::string kMlirBridgeFallback = "mlir_bridge_op_fallback_enabled"; - - mlir_bridge_status = CompileFromMlirToXlaHlo( - /*lower_to_xla_hlo=*/true, computation, metadata, device_type, - shape_determination_fns, use_tuple_args, compilation_result, - custom_legalization_passes, arg_shapes, arg_core_mapping, - per_core_arg_shapes); - - phase2_bridge_compilation_time->GetCell(kMlirBridgeFallback) - ->Add(timer.ElapsedCyclesInMilliseconds()); - } + tsl::StatusOr mlir_bridge_status = CompileFromMlirToXlaHlo( + /*lower_to_xla_hlo=*/true, computation, metadata, device_type, + shape_determination_fns, use_tuple_args, compilation_result, + custom_legalization_passes, arg_shapes, arg_core_mapping, + per_core_arg_shapes); if (mlir_bridge_status.ok()) { VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc index 8d7531f87eed3e..2d321246463494 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h" #include +#include #include #include "absl/log/log.h" @@ -33,19 +34,11 @@ limitations under the License. 
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/lib/monitoring/sampler.h" #include "tsl/platform/statusor.h" namespace tensorflow { namespace tf2xla { namespace internal { -auto* phase2_combined_bridge_compilation_time = - tsl::monitoring::Sampler<1>::New( - {"/tensorflow/core/tf2xla/api/v2/phase2_combined_compilation_time", - "The wall-clock time spent on combined graphs in milliseconds.", - "configuration"}, - // Power of 1.5 with bucket count 45 (> 23 hours) - {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); using metrics::IncrementTfMlirBridgeSecondPhaseCounter; using metrics::MlirBridgeSecondPhaseMetric; @@ -63,14 +56,14 @@ tsl::StatusOr LegalizeTfToHlo( xla::CompileOnlyClient* client, XlaCompilationResult* compilation_result) { LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " "Combined MLIR Tf2Xla Bridge."; - CompilationTimer timer; - constexpr char kCombinedBridgeTimer[] = "combined_bridge"; - auto mlir_compilation = internal::CompileFromMlirToXlaHlo( - /*lower_to_xla_hlo=*/false, computation, metadata, device_type, - shape_determination_fns, use_tuple_args, compilation_result, - custom_legalization_passes, arg_shapes, arg_core_mapping, - per_core_arg_shapes); + tsl::StatusOr mlir_compilation + + = internal::CompileFromMlirToXlaHlo( + /*lower_to_xla_hlo=*/false, computation, metadata, device_type, + shape_determination_fns, use_tuple_args, compilation_result, + custom_legalization_passes, arg_shapes, arg_core_mapping, + per_core_arg_shapes); if (!mlir_compilation.ok()) { IncrementTfMlirBridgeSecondPhaseCounter( @@ -94,8 +87,6 @@ tsl::StatusOr LegalizeTfToHlo( IncrementTfMlirBridgeSecondPhaseCounter( MlirBridgeSecondPhaseMetric::kMlirCombinedOldSuccess); - phase2_combined_bridge_compilation_time->GetCell(kCombinedBridgeTimer) - ->Add(timer.ElapsedCyclesInMilliseconds()); return 
*compilation_result; } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc new file mode 100644 index 00000000000000..3672b0bbe0f47a --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.cc @@ -0,0 +1,177 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/tf2xla_defs.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/graph/graph.h" +#include "tsl/platform/status.h" + +namespace tensorflow { + +using ::mlir::failure; +using ::mlir::LogicalResult; +using ::mlir::success; + +namespace { +LogicalResult HasAttr( + const Graph& graph, const FunctionLibraryDefinition* function_library, + const std::function& 
predicate) { + if (predicate(graph)) { + return success(); + } + + // Check if any reachable functions from the graph have the target attribute. + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + if (!function_library) return failure(); + for (const std::string& func_name : + function_library->ReachableDefinitions(graph_def).ListFunctionNames()) { + const FunctionDef* func_def = function_library->Find(func_name); + std::unique_ptr func_body; + absl::Status status = FunctionDefToBodyHelper( + *func_def, AttrSlice(&func_def->attr()), function_library, &func_body); + // This is not expected to happen in practice + if (!status.ok()) { + LOG(ERROR) << "Failed to parse " << func_name << ": " + << tsl::NullTerminatedMessage(status); + return failure(); + } + if (predicate(*func_body->graph)) { + return success(); + } + } + return failure(); +} + +bool IsNonReplicatedGraph(const Graph& graph, + const FunctionLibraryDefinition* function_library) { + auto predicate = [](const Graph& graph) { + const std::string kStatefulPartitionedCallOp = "StatefulPartitionedCall"; + for (const Node* node : graph.nodes()) { + auto node_op = node->type_string(); + if (node_op == kStatefulPartitionedCallOp) { + // Functions called by StatefulPartitionedCall ops with + // _XlaMustCompile=true are compiled by XLA. + auto attr = node->attrs().FindByString(std::string(kMustCompileAttr)); + if (attr != nullptr && attr->b() == true) { + return true; + } + } + } + return false; + }; + return HasAttr(graph, function_library, predicate).succeeded(); +} + +bool IsReplicatedGraph(const Graph& graph, + const FunctionLibraryDefinition* function_library) { + auto predicate = [](const Graph& graph) { + for (const Node* node : graph.nodes()) { + // _tpu_replicate is used in replicated TPU graphs. It will be converted + // to _replication_info and _xla_compile_device_type in phase 1 pipelines.
+ if (node->attrs().FindByString(std::string(kTpuReplicateAttr))) { + return true; + } + } + return false; + }; + return HasAttr(graph, function_library, predicate).succeeded(); +} + +bool IsSingleCoreTpuGraph(const Graph& graph, + const FunctionLibraryDefinition* function_library) { + auto predicate = [](const Graph& graph) { + for (const Node* node : graph.nodes()) { + // _xla_compile_device_type=TPU is found in single-core TPU graphs. + auto attr = + node->attrs().FindByString(std::string(kCompileDeviceTypeAttr)); + if (attr && attr->s() == kTpuDevice) { + return true; + } + } + return false; + }; + return HasAttr(graph, function_library, predicate).succeeded(); +} + +bool IsReplicatedGraph(mlir::ModuleOp module) { + auto walk_result = module.walk([&](mlir::Operation* op) { + // TODO(b/223677572): Once the scope for new compilation and replication + // markers is expanded beyond bridge we can remove this check for + // `kTPUReplicateAttr`, we will then always have a `kCompileDeviceTypeAttr` + // in such cases (see above). + // TODO(b/229028654): Remove string conversion once we have C++17. + const llvm::StringRef tpu_replicate_attr_name(kTpuReplicateAttr.data(), + kTpuReplicateAttr.size()); + auto replicate_attr = + op->getAttrOfType(tpu_replicate_attr_name); + if (replicate_attr) return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }); + return walk_result.wasInterrupted(); +} + +bool IsSingleCoreTPUGraph(mlir::ModuleOp module) { + auto walk_result = module.walk([&](mlir::Operation* op) { + // Check for ops with compile device type "TPU". This allows us to support + // TPU compilation without replication. Note that currently the compile + // device type is not set by default before bridge, only if eager context + // attribute `jit_compile_rewrite` is true. + // TODO(b/229028654): Remove string conversion once we have C++17. 
+ const llvm::StringRef compile_device_type_attr_name( + kCompileDeviceTypeAttr.data(), kCompileDeviceTypeAttr.size()); + auto compilation_attr = + op->getAttrOfType(compile_device_type_attr_name); + if (compilation_attr && compilation_attr.getValue().str() == kTpuDevice) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + return walk_result.wasInterrupted(); +} + +} // namespace + +bool IsSupportedByNonReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library) { + return IsNonReplicatedGraph(graph, function_library); +} + +bool IsSupportedByReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library) { + return IsReplicatedGraph(graph, function_library) || + IsSingleCoreTpuGraph(graph, function_library); +} + +bool IsSupportedByReplicatedBridge(mlir::ModuleOp module) { + return IsReplicatedGraph(module) || IsSingleCoreTPUGraph(module); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h new file mode 100644 index 00000000000000..5ea0bdc71ea0d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { + +// Checks if a graph or reachable functions in the library have any +// StatefulPartitionedOps with _XlaMustCompile=true. The function library will +// be skipped if nullptr is provided. +bool IsSupportedByNonReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); + +// Checks if a graph or reachable functions in the library have any ops with +// _tpu_replicate or _xla_compile_device_type=TPU. The function library will be +// skipped if nullptr is provided. + +bool IsSupportedByReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); + +// Check if an MLIR module has any ops with _tpu_replicate or +// _xla_compile_device_type=TPU. +bool IsSupportedByReplicatedBridge(mlir::ModuleOp module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc new file mode 100644 index 00000000000000..8ce19560f1f349 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" + +#include + +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/tf2xla/tf2xla_defs.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/enable_tf2_utils.h" +#include "tsl/lib/core/status_test_util.h" + +namespace tensorflow { + +namespace { + +FunctionDef OuterXTimesTwo() { + return FunctionDefHelper::Define( + // Name + "OuterXTimesTwo", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attr def + {}, + {{{"y"}, + "StatefulPartitionedCall", + {"x"}, + {{"Tin", DataTypeSlice{DT_FLOAT}}, + {"Tout", DataTypeSlice{DT_FLOAT}}, + {"f", + FunctionDefHelper::FunctionRef("XTimesTwoFloat", {{"T", DT_FLOAT}})}, + {std::string(kMustCompileAttr), true}}}}); +} + +TEST(IsSupportedByNonReplicatedBridge, NonReplicatedGraph) { + 
const FunctionDef& fd = test::function::XTimesTwo(); + FunctionDefLibrary flib; + *flib.add_function() = fd; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + tensorflow::set_tf2_execution(true); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + + Node* call; + NameAttrList f_name_attr; + f_name_attr.set_name(fd.signature().name()); + TF_ASSERT_OK( + NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) + .Input(inputs) + .Attr("Tin", {DT_FLOAT}) + .Attr("Tout", {DT_FLOAT}) + .Attr("f", f_name_attr) + .Finalize(root.graph(), &call)); + call->AddAttr(std::string(kMustCompileAttr), true); + + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE( + IsSupportedByNonReplicatedBridge(graph, /*function_library=*/nullptr)); +} + +// Checks that HasAttr actually goes through function library. +TEST(IsSupportedByNonReplicatedBridge, NonReplicatedFunctionLibrary) { + const FunctionDef& fd = OuterXTimesTwo(); + FunctionDefLibrary flib; + *flib.add_function() = fd; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + tensorflow::set_tf2_execution(true); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + + // Builds a call without compilation markers that calls a function with Xla + // clusters. 
+ Node* call; + NameAttrList f_name_attr; + f_name_attr.set_name(fd.signature().name()); + TF_ASSERT_OK( + NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) + .Input(inputs) + .Attr("Tin", {DT_FLOAT}) + .Attr("Tout", {DT_FLOAT}) + .Attr("f", f_name_attr) + .Finalize(root.graph(), &call)); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_TRUE( + IsSupportedByNonReplicatedBridge(graph, /*function_library=*/&flib_def)); +} + +TEST(IsSupportedByReplicatedBridge, ReplicatedGraph) { + const FunctionDef& fd = test::function::XTimesTwo(); + FunctionDefLibrary flib; + *flib.add_function() = fd; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + tensorflow::set_tf2_execution(true); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + + Node* call; + NameAttrList f_name_attr; + f_name_attr.set_name(fd.signature().name()); + TF_ASSERT_OK( + NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) + .Input(inputs) + .Attr("Tin", {DT_FLOAT}) + .Attr("Tout", {DT_FLOAT}) + .Attr("f", f_name_attr) + .Finalize(root.graph(), &call)); + call->AddAttr(std::string(kTpuReplicateAttr), "cluster"); + + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE( + IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); +} + +TEST(IsSupportedByReplicatedBridge, SingleCoreTpuGraph) { + const FunctionDef& fd = test::function::XTimesTwo(); + FunctionDefLibrary flib; + *flib.add_function() = fd; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + tensorflow::set_tf2_execution(true); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a 
= ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0); + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + + Node* call; + NameAttrList f_name_attr; + f_name_attr.set_name(fd.signature().name()); + TF_ASSERT_OK( + NodeBuilder("B", "StatefulPartitionedCall", &root.graph()->flib_def()) + .Input(inputs) + .Attr("Tin", {DT_FLOAT}) + .Attr("Tout", {DT_FLOAT}) + .Attr("f", f_name_attr) + .Finalize(root.graph(), &call)); + call->AddAttr(std::string(kCompileDeviceTypeAttr), kTpuDevice); + + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE( + IsSupportedByReplicatedBridge(graph, /*function_library=*/nullptr)); +} + +TEST(IsSupportedByReplicatedBridge, ReplicatedModule) { + const char* const code = R"mlir( +func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.Identity"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> (tensor) + func.return %0 : tensor +} +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); +} + +TEST(IsSupportedByReplicatedBridge, SingleCoreTpuModule) { + const char* const code = R"mlir( +func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.Identity"(%arg0) {_xla_compile_device_type = "TPU"} : (tensor) -> (tensor) + func.return %0 : tensor +} +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + EXPECT_TRUE(IsSupportedByReplicatedBridge(*module)); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index 6250f2cf0ca7c5..df25fdd9ffe4be 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -26,6 +26,7 @@ 
cc_library( deps = [ ":extract_head_tail_outside_compilation", ":extract_outside_compilation", + ":hoist_broadcast_read", ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", ":verify_clustering_pass", @@ -351,6 +352,41 @@ cc_library( ], ) +cc_library( + name = "hoist_broadcast_read", + srcs = ["hoist_broadcast_read.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalization_op_config", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + tf_cc_test( name = "tpu_cluster_formation_test", srcs = ["tpu_cluster_formation_test.cc"], diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index ea6187a2309205..3ccf990a4b5272 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ 
b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -56,6 +56,11 @@ CreateXlaOutlineEntryFunctionsPass(); std::unique_ptr> CreateMarkOpsForOutsideCompilationPass(); +// Creates a pass that hoists reads out of a replicate that are on a variable +// whose value is broadcast to all replicas. +std::unique_ptr> +CreateHoistBroadcastReadPass(); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index c219c35842c401..90d2e962bc9b1c 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -349,3 +349,42 @@ def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation let constructor = "tensorflow::tf2xla::internal::CreateMarkOpsForOutsideCompilationPass()"; } + +def HoistBroadcastReadPass : Pass<"tf-hoist-broadcast-read", "mlir::func::FuncOp"> { + let summary = "Hoist reads out of a replicate that are on a resource that is broadcast to all replicas."; + + let description = [{ + Some `ReadVariableOp`s that are within a `tf_device.replicate` read the same + value across all replicas. These reads can be hoisted out of the + `tf_device.replicate` so there's one read for all replicas, and each replica + depends on the result of the read. This transform enables the + xla-broadcast-pass to optimize the broadcast value.
+ + For example, the following: + + ```mlir + tf_device.replicate {n = 2 : i32} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + "tf.OpA"(%0) : (tensor) -> () + } + ``` + + will be transformed into: + + ```mlir + %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor + tf_device.replicate {n = 2 : i32} { + "tf.OpA"(%0) : (tensor) -> () + } + ``` + + We must ensure that there is a single underlying resource that is not + distributed across replicas. There is a single underlying resource when the + resource device type is CPU, so we cautiously only apply in this case. + + To be cautious we never hoist a read that comes after a write to the same + resource. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateHoistBroadcastReadPass()"; } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc new file mode 100644 index 00000000000000..732bae8c67b018 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc @@ -0,0 +1,154 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +using mlir::BlockArgument; +using mlir::failure; +using mlir::LogicalResult; +using mlir::Operation; +using mlir::OperationPass; +using mlir::OpOperand; +using mlir::StringAttr; +using mlir::success; +using mlir::Value; +using mlir::WalkResult; +using mlir::func::FuncOp; +using mlir::TF::ReadVariableOp; +using mlir::tf_device::ReplicateOp; + +#define GEN_PASS_DEF_HOISTBROADCASTREADPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +constexpr char kFuncDeviceAttr[] = "tf.device"; +constexpr char kCpuDeviceType[] = "CPU"; + +struct HoistBroadcastRead + : public impl::HoistBroadcastReadPassBase { + void runOnOperation() override; +}; + +// Get the ancestor of `descendant` that is a direct child of `ancestor`. +Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) { + Operation* parent = descendant->getParentOp(); + if (!parent) return nullptr; + if (parent == ancestor) return descendant; + return GetAncestorBelow(parent, ancestor); +} + +// `is_cpu_read` is set to `true` iff `read` is on a resource with device type +// CPU. 
+LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) { + if (auto arg = read->getOperand(0).dyn_cast()) { + if (arg.getOwner() != &(func.front())) { + is_cpu_read = false; + return success(); + } + if (auto attr = func.getArgAttrOfType(arg.getArgNumber(), + kFuncDeviceAttr)) { + std::string device = attr.getValue().str(); + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(device, &parsed_name)) { + return read->emitOpError() << "invalid device '" << device << "'"; + } + is_cpu_read = parsed_name.type == kCpuDeviceType; + return success(); + } + } + is_cpu_read = false; + return success(); +} + +// Get the reads to hoist in the `replicate`. +LogicalResult GetReads(FuncOp func, ReplicateOp replicate, + llvm::SmallVector& reads) { + for (Operation& op : replicate.getBody().front()) { + if (auto read = llvm::dyn_cast(&op)) { + bool is_cpu_read; + if (failed(IsCpuRead(func, read, is_cpu_read))) return failure(); + if (is_cpu_read) reads.push_back(read); + } + } + return success(); +} + +// Move reads above the `replicate`. Skip reads that come after a write to the +// same resource. +void MoveReads(ReplicateOp replicate, + llvm::SmallVector& reads) { + for (ReadVariableOp read : reads) { + Value res = read.getResource(); + Operation* scope = res.getParentBlock()->getParentOp(); + if (!scope->isProperAncestor(replicate)) continue; + bool has_conflicting_write = false; + for (OpOperand& use : res.getUses()) { + Operation* using_op = use.getOwner(); + if (using_op == read) continue; + if (!replicate->isProperAncestor(using_op)) continue; + Operation* peer = GetAncestorBelow(using_op, replicate); + if (read->isBeforeInBlock(peer)) continue; + if (llvm::isa(peer)) continue; + has_conflicting_write = true; + } + if (has_conflicting_write) continue; + read->moveBefore(replicate); + } +} + +// Hoist `ReadVariableOp`s above the `tf_device.replicate`s. 
+void HoistBroadcastRead::runOnOperation() { + FuncOp func = getOperation(); + + auto result = func.walk([&](ReplicateOp replicate) { + llvm::SmallVector reads; + if (failed(GetReads(func, replicate, reads))) + return WalkResult::interrupt(); + MoveReads(replicate, reads); + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace + +std::unique_ptr> CreateHoistBroadcastReadPass() { + return std::make_unique(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir index 7ba98798c126df..c098bf494272fd 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.mlir @@ -3,7 +3,7 @@ func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error@below {{op is in dialect chlo not in tf functional dialect}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir index 5a6fda697d23fa..7d88228447e4af 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_input_dialect_to_executor_pass_test.mlir @@ -11,6 +11,10 @@ func.func 
@testNoClusterFuncOpPasses(%arg0: tensor<4x?x!tf_type.stringref>) -> t // ----- +func.func @_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + func.func @testClusterFuncOpFails(%arg0: tensor) -> tensor { // expected-error@below {{failed TF functional to executor validation, op tf_device.cluster_func is not allowed}} %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor) -> tensor @@ -29,6 +33,6 @@ func.func @testTFDialect(%arg0: tensor<4x?x!tf_type.stringref>) -> tensor<4x2x!t func.func @testNotTfDialect(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error@below {{op is in dialect chlo which is not an accepted dialect}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc index 3f35813744c60d..f00e12690d380b 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc @@ -24,7 +24,8 @@ namespace tensorflow { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool run_tpu_bridge, + std::optional config_proto, + bool is_supported_by_replicated_brige, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats) { switch (GetMlirBridgeRolloutState(config_proto)) { diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h index 
5c7f47a219e10e..66a68ae53f535f 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -53,7 +53,8 @@ enum class MlirBridgeRolloutPolicy { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool run_tpu_bridge, + std::optional config_proto, + bool is_supported_by_replicated_brige, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats); diff --git a/tensorflow/compiler/mlir/tf2xla/tests/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/BUILD index c68c485954de1b..97bb01c30d1855 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir index 11cfffc24eaa33..d79804a6d38cc6 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir @@ -119,7 +119,9 @@ func.func @send_to_host(%arg0: tensor) { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: is_host_transfer = true // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_key_dtoh_0"} // CHECK-SAME: mhlo.sharding = 
"\08\01\1A\01\01\22\01\00" @@ -137,7 +139,9 @@ func.func @recv_from_host() -> tensor { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: [[RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: is_host_transfer = true // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_key_htod_0"} // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" @@ -158,21 +162,29 @@ func.func @multiple_consecutive_ops(%arg0: tensor) -> tensor { // CHECK: [[INIT_TOKEN:%.*]] = mhlo.create_token // CHECK: [[SEND0_ARG0_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send0_dtoh_0"} // CHECK: [[RECV0_RETVAL0_TUPLE:%.*]]:2 = "mhlo.recv"([[SEND0_ARG0_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv0_htod_0"} %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "recv0", send_key = "send0", host_mlir_module = ""} : (tensor) -> tensor // CHECK: [[SEND1_ARG0_TOKEN:%.*]] = "mhlo.send"([[RECV0_RETVAL0_TUPLE]]#0, [[RECV0_RETVAL0_TUPLE]]#1) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", 
_xla_host_transfer_rendezvous = "send1_dtoh_0"} // CHECK: [[RECV1_RETVAL0_TUPLE:%.*]]:2 = "mhlo.recv"([[SEND1_ARG0_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv1_htod_0"} %1 = "tf._XlaHostComputeMlir"(%0) {recv_key = "recv1", send_key = "send1", host_mlir_module = ""} : (tensor) -> tensor @@ -376,11 +388,15 @@ func.func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK: [[IF:%.*]]:2 = "mhlo.if"([[ARG0]]) %0 = "mhlo.if"(%arg0) ({ // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG1]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} // CHECK: [[TRUE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[TRUE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_if_true", send_key = "send_if_true", host_mlir_module = ""} : (tensor) -> tensor @@ -388,11 +404,15 @@ func.func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor "mhlo.return"(%1) : (tensor) -> () }, { // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG2]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: 
mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} // CHECK: [[FALSE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[FALSE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg2) {recv_key = "recv_if_false", send_key = "send_if_false", host_mlir_module = ""} : (tensor) -> tensor @@ -419,11 +439,15 @@ func.func @if_true_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} // CHECK: [[TRUE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[TRUE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_if_true", send_key = "send_if_true", host_mlir_module = ""} : (tensor) -> tensor @@ -456,11 +480,15 @@ func.func @if_false_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor< "mhlo.return"(%arg1) : (tensor) -> () }, { // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[ARG2]], [[INIT_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", 
_xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} // CHECK: [[FALSE_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[FALSE_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg2) {recv_key = "recv_if_false", send_key = "send_if_false", host_mlir_module = ""} : (tensor) -> tensor @@ -681,11 +709,15 @@ func.func @while_cond_body(%arg0: tensor) -> tensor { %0 = "mhlo.while"(%arg0) ({ ^bb0(%arg1: tensor): // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} // CHECK: [[COND_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[COND_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", host_mlir_module = ""} : (tensor) -> tensor @@ -697,11 +729,15 @@ func.func @while_cond_body(%arg0: tensor) -> tensor { }, { ^bb0(%arg1: tensor): // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: 
mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} // CHECK: [[BODY_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[BODY_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", host_mlir_module = ""} : (tensor) -> tensor @@ -727,11 +763,15 @@ func.func @while_cond(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} // CHECK: [[COND_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[COND_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", host_mlir_module = ""} : (tensor) -> tensor @@ -772,11 +812,15 @@ func.func @while_body(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[ITER_ARG_VALUE]], [[ITER_ARG_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle 
= + // CHECK-SAME: type = 2 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} // CHECK: [[BODY_RECV_TUPLE:%.*]]:2 = "mhlo.recv"([[BODY_SEND_TOKEN]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME: handle = + // CHECK-SAME: type = 3 // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", host_mlir_module = ""} : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir index 44997584147a0f..8e2876f5707deb 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-quant.mlir @@ -368,7 +368,7 @@ func.func @uniform_quantized_add(%arg0: tensor<3x2x!tf_type.qint32>) -> tensor<3 // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> - // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[RES:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> // CHECK: %[[RES_INT:.*]] = mhlo.convert %[[RES]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> @@ -418,10 +418,10 @@ func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> tensor<3x // CHECK-DAG: %[[CONVERT_1:.*]] = mhlo.convert %[[OPERAND]] : (tensor<3x2x!quant.uniform>) -> tensor<3x2xi32> // CHECK-DAG: 
%[[CONVERT_2:.*]] = mhlo.convert %[[CONVERT_1]] : (tensor<3x2xi32>) -> tensor<3x2x!quant.uniform> - // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[CONVERT_2]], %[[MIN_MAX]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[CONVERT_2]], %[[MIN_MAX]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> - // CHECK: chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MIN_MAX]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK: chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MIN_MAX]] {broadcast_dimensions = array} : // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) // CHECK-SAME: -> tensor<3x2x!quant.uniform> %1 = "tf.UniformQuantizedClipByValue"(%0, %min, %max, %scales, %zps) { @@ -557,4 +557,4 @@ func.func @while_region_with_quant_two_args(%arg0: tensor<2x2xf32>, %arg1: tenso // return %[[RESULT0]], %[[RESULT1]] func.return %3, %4 : tensor<2x?xf32>, tensor -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index ba0f70e369f3bf..018046d66d57a2 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -166,7 +166,7 @@ func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: 
%[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -218,7 +218,7 @@ func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -299,7 +299,7 @@ func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -381,7 +381,7 @@ func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : 
tensor - // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = array} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = mhlo.rsqrt %[[add]] : tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -739,7 +739,7 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten // CHECK-LABEL: func @erf func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: chlo.erf %arg0 : tensor<2x3xf32> + // CHECK: mhlo.erf %arg0 : tensor<2x3xf32> %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> func.return %0 : tensor<2x3xf32> } @@ -787,14 +787,14 @@ func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @floordiv_broadcast_i32 func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = 
#chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -808,14 +808,14 @@ func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]] - // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = 
mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -865,14 +865,14 @@ func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) // CHECK-LABEL: func @floordiv_dynamic func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = array} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = #chlo} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] @@ -886,7 +886,7 @@ func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> ten // CHECK-LABEL: func @floordiv_unsigned func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 
{broadcast_dimensions = array} // CHECK: return [[DIV]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0: tensor @@ -926,12 +926,12 @@ func.func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] @@ -945,15 +945,15 @@ func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3x // CHECK-LABEL: func @floormod_broadcast_denominator func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = 
chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array} // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -964,7 +964,7 @@ func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor< // CHECK-LABEL: func @floormod_unsigned_broadcast_denominator func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-NEXT: return [[REM]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> func.return %0: tensor<2x3xui32> @@ -974,15 +974,15 @@ func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, 
%arg // CHECK-LABEL: func @floormod_dynamic_broadcast_numerator func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = #chlo} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = array, comparison_direction = #chlo} // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = array} // CHECK-DAG: [[SELECT:%.+]] = mhlo.select [[AND]], [[ADD]], [[REM]] // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -994,12 +994,12 @@ func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: // CHECK-LABEL: func @floormod_dynamic_broadcast_denominator func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-NOT: tf.FloorMod - // CHECK-DAG: [[REM:%.+]] = 
chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor) -> tensor + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} : (tensor, tensor) -> tensor + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = array, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] : (tensor, tensor) -> tensor // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] : (tensor, tensor) -> tensor @@ -1839,8 +1839,8 @@ func.func @elu_unranked(%arg0: tensor) -> tensor { func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = #chlo} - // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = array, 
comparison_direction = #chlo} + // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = array} // CHECK-DAG: %[[MULGRAD:.*]] = mhlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]] // CHECK: return %[[RESULT]] @@ -1857,7 +1857,7 @@ func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> // CHECK-LABEL: func @relu func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> } @@ -1867,7 +1867,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func.func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor func.return %0: tensor } @@ -1877,7 +1877,7 @@ func.func @relu_unranked(%arg0: tensor) -> tensor { // CHECK-LABEL: func @relu_unsigned func.func @relu_unsigned(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = array} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor 
func.return %0: tensor } @@ -2017,7 +2017,7 @@ func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> te // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> - // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = array} : (tensor, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> @@ -2775,7 +2775,7 @@ func.func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xc // CHECK-LABEL: @sigmoid_grad_dynamic func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor - // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK: chlo.broadcast_multiply {{.*}} : (tensor, tensor) -> tensor %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0 : tensor @@ -3662,7 +3662,7 @@ func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[MEAN:.*]] = 
chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = array} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[MEAN]] : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -3698,7 +3698,7 @@ func.func @mean_dynamic(%arg0: tensor) -> tensor { // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert %[[MEAN]] : (tensor) -> tensor // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor -> tensor<1xindex> // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -4120,8 +4120,8 @@ func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func.func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = array} + // CHECK: 
chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> func.return %3 : tensor<5xf32> } @@ -4142,8 +4142,8 @@ func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -4166,8 +4166,8 @@ func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tens // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -4186,8 +4186,8 @@ func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], 
[[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} + // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -5334,7 +5334,7 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-SAME: -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x3x5x7xf16> @@ -5360,7 +5360,7 @@ func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7 // CHECK-SAME: -> tensor<2x4x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x4x3x5x7xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x4x3x5x7xf16> @@ -5386,7 +5386,7 @@ func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x // CHECK-SAME: -> tensor<2x7x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> 
: tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x7x3x5xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x7x3x5xf16> @@ -5412,7 +5412,7 @@ func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf // CHECK-SAME: -> tensor<2x7x4x3x5xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<2x7x4x3x5xf32> // CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] // CHECK-SAME: -> tensor<2x7x4x3x5xf16> @@ -5497,7 +5497,7 @@ func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<10x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> @@ -5530,7 +5530,7 @@ func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor< // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> +// CHECK: 
%[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = array} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> @@ -5724,7 +5724,7 @@ func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor< // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor // CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: broadcast_dimensions = array // CHECK-SAME: -> tensor<10x12x16x64xbf16> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir index 09374ca8006a2f..673e6c9ffd329a 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization-no-chlo.mlir @@ -6,6 +6,6 @@ // CHECK-LABEL: allows_chlo func.func @allows_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir 
b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir index c91c7c9da8ac77..e6623350380fcb 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/verify-tfxla-legalization.mlir @@ -31,7 +31,7 @@ func.func @invalid_mixed_mhlo() -> (tensor<8x64x128xcomplex> {mhlo.sharding func.func @fails_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { // expected-error @+1 {{Could not legalize op: chlo.broadcast_add}} - %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = array} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> func.return %0 : tensor<1x32x10x32xi32> } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 0f0b1182e50bb7..28a459ccff2eac 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -270,7 +270,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/util/quantization:uniform_quant_ops_params", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index b8818bb9d3824b..4d96bb24acf152 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -132,7 +132,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // a new op, we should expect 
these to change too. EXPECT_EQ(mlir_lowering_count, 67); EXPECT_EQ(tf2xla_fallback_count, 316); - EXPECT_EQ(non_categorized_count, 422); + EXPECT_EQ(non_categorized_count, 423); } // Just a counter test to see which ops have duplicate lowerings. This isn't a @@ -224,7 +224,7 @@ TEST_F(LegalizationOpConfigTest, MlirLoweringWithoutXlaKernel) { } } - EXPECT_EQ(mlir_without_xla_count, 14); + EXPECT_EQ(mlir_without_xla_count, 13); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 70370bffc41f20..76056458079964 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -497,7 +498,7 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create(loc, builder->getI32IntegerAttr(1)); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); auto plus_one = builder->create( loc, block->getArgument(0), one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. 
@@ -2172,7 +2173,7 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = @@ -2315,7 +2316,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, - factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); + factor_const_op, /*broadcast_dimensions=*/DenseI64ArrayAttr()); // Convert back to input type to stay aligned with expected output type // for TF op. @@ -2335,24 +2336,24 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // new_running_mean = alpha * old_mean + beta * batch_mean. auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.getMean().getType(), alpha, op.getMean(), - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. 
auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); auto beta_mul_batch_variance = rewriter.create( op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); + /*broadcast_dimensions=*/DenseI64ArrayAttr()); } if (std::is_same::value) { @@ -2522,7 +2523,7 @@ Operation *AvgPoolDivideByCount( // Divide `pooled` by window counts. Value divisor = GetScalarConstOfType(element_type, loc, window_count, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create( loc, pooled_type, pooled, divisor, scalar_broadcast_dims); } else { @@ -4091,7 +4092,7 @@ class GenericConvertReductionOp : public OpRewritePattern { Value divisor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), divisor_tensor); - auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create(loc, result, divisor, broadcast_dims); } @@ -6103,7 +6104,7 @@ class ConvertXlaReduceScatterOp if (replica_group_size == 0) return failure(); auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); result = rewriter.create( loc, result, divisor.getResult(), broadcast_dims); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 
d7937cce42ff24..54bd5812644488 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -176,7 +176,7 @@ LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id, } auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, &builder); - auto broadcast_dims = GetI64ElementsAttr({}, &builder); + auto broadcast_dims = builder.getDenseI64ArrayAttr({}); result = builder.create( loc, all_reduce.getResult(0), divisor.getResult(), broadcast_dims); } else if (final_op != "Id") { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 763e94734f6d01..3e8dd5b58ed2f1 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -16,6 +16,9 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect's communication // ops (TF/XLA) to the HLO dialect. +#include +#include +#include #include #include #include @@ -31,6 +34,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -69,6 +73,28 @@ class LegalizeTFCommunication void runOnOperation() override; }; +// A generator to serve out unique channel ids. 
+class ChannelIdGenerator { + public: + ChannelIdGenerator() = default; + ChannelIdGenerator(const ChannelIdGenerator&) = delete; + ChannelIdGenerator& operator=(const ChannelIdGenerator&) = delete; + ChannelIdGenerator(ChannelIdGenerator&&) = delete; + ChannelIdGenerator& operator=(ChannelIdGenerator&&) = delete; + int64_t operator++(int) { return next(); } + int64_t next() { return channel_id_.fetch_add(1, std::memory_order_relaxed); } + + private: + // All usage code expects positive int64_t values so we can't use uint64_t + // and will just have to limit ourselves to half the number space. + std::atomic channel_id_ = 1; +}; + +int64_t GetNextChannelId() { + static ChannelIdGenerator* channel_id = new ChannelIdGenerator(); + return channel_id->next(); +} + // Checks if an op is a TF/XLA communication op. bool IsCommunicationOp(Operation* op) { return isa( loc, token.getType(), operand, token, channel_handle, @@ -273,12 +299,12 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, } // Creates a `mhlo.recv` op for receiving a value. -Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value result, StringRef key, size_t index, Value token, +Value CreateRecvOp(OpBuilder& builder, Location loc, Value result, + StringRef key, size_t index, Value token, StringRef host_handler_name, bool manual_sharding) { // type 3 == HOST_TO_DEVICE auto channel_handle = ChannelHandleAttr::get(builder.getContext(), - /*handle=*/channel_id++, + /*handle=*/GetNextChannelId(), /*type=*/3); auto result_type = result.getType(); SmallVector recv_result_type = {result_type, token.getType()}; @@ -315,7 +341,7 @@ Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef tokens, // ops per operand and result. Unique Channel IDs are assigned per transfer. // Sink tokens are created across all `mhlo.send` ops first and then by // all `mhlo.recv` ops. 
-Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, +Value RewriteHostComputeOp(OpBuilder& builder, TF::_XlaHostComputeMlirOp host_compute, Value token) { builder.setInsertionPoint(host_compute); @@ -325,7 +351,7 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, SmallVector send_tokens; for (auto operand : llvm::enumerate(host_compute.getInputs())) { auto send_token = CreateSendOp( - builder, channel_id, loc, operand.value(), host_compute.getSendKey(), + builder, loc, operand.value(), host_compute.getSendKey(), operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, manual_sharding); send_tokens.push_back(send_token); @@ -335,9 +361,8 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, SmallVector recv_tokens; for (auto result : llvm::enumerate(host_compute.getOutputs())) { auto recv_token = CreateRecvOp( - builder, channel_id, loc, result.value(), host_compute.getRecvKey(), - result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, - manual_sharding); + builder, loc, result.value(), host_compute.getRecvKey(), result.index(), + token, xla::kXlaHostTransferTfRendezvousHandlerName, manual_sharding); recv_tokens.push_back(recv_token); } token = CreateSinkToken(builder, loc, recv_tokens, token); @@ -347,11 +372,11 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, } // Replaces `tf.XlaSendToHost` with a `mhlo.send`. 
-Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, - TF::XlaSendToHostOp send_to_host, Value token) { +Value RewriteSendToHostOp(OpBuilder& builder, TF::XlaSendToHostOp send_to_host, + Value token) { builder.setInsertionPoint(send_to_host); - token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), - send_to_host.getInput(), send_to_host.getKey(), + token = CreateSendOp(builder, send_to_host.getLoc(), send_to_host.getInput(), + send_to_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName, /*manual_sharding=*/false); @@ -361,10 +386,10 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, } // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`. -Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, +Value RewriteRecvFromHostOp(OpBuilder& builder, TF::XlaRecvFromHostOp recv_from_host, Value token) { builder.setInsertionPoint(recv_from_host); - token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), + token = CreateRecvOp(builder, recv_from_host.getLoc(), recv_from_host.getOutput(), recv_from_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName, @@ -795,7 +820,7 @@ void RewriteFunctionTerminator(OpBuilder& builder, // rewritten to create a token or take in and return a token, depending on its // visibility and if there are any callers. 
LogicalResult RewriteFunction( - OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func, + OpBuilder& builder, ModuleOp module, FuncOp func, const llvm::SmallDenseMap& funcs, const llvm::SmallPtrSetImpl& control_flow_ops, const llvm::SmallPtrSetImpl& control_flow_blocks, bool is_clone) { @@ -832,11 +857,11 @@ LogicalResult RewriteFunction( Operation* next_op = curr_op->getNextNode(); if (auto host_compute = dyn_cast(curr_op)) { - token = RewriteHostComputeOp(builder, channel_id, host_compute, token); + token = RewriteHostComputeOp(builder, host_compute, token); } else if (auto send_to_host = dyn_cast(curr_op)) { - token = RewriteSendToHostOp(builder, channel_id, send_to_host, token); + token = RewriteSendToHostOp(builder, send_to_host, token); } else if (auto recv_from_host = dyn_cast(curr_op)) { - token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token); + token = RewriteRecvFromHostOp(builder, recv_from_host, token); } else if (auto call = dyn_cast(curr_op)) { // Only `mlir::func::CallOp` is supported as this requires knowing how to // rewrite arguments and results to a function. @@ -929,14 +954,11 @@ void LegalizeTFCommunication::runOnOperation() { if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite))) return signalPassFailure(); - // Module level counter to make sure Channel IDs are unique. 
- int64_t channel_id = 1; OpBuilder builder(&getContext()); for (const auto& func_and_name : funcs_to_rewrite) { const auto& func_to_rewrite = func_and_name.getSecond(); func::FuncOp func = func_to_rewrite.original; - if (failed(RewriteFunction(builder, channel_id, module, func, - funcs_to_rewrite, + if (failed(RewriteFunction(builder, module, func, funcs_to_rewrite, func_to_rewrite.control_flow_ops, func_to_rewrite.control_flow_blocks, /*is_clone=*/false))) @@ -949,8 +971,8 @@ void LegalizeTFCommunication::runOnOperation() { GetCommunicationControlFlowOps(clone, funcs_to_rewrite, clone_control_flow_ops, clone_control_flow_blocks); - if (failed(RewriteFunction(builder, channel_id, module, clone, - funcs_to_rewrite, clone_control_flow_ops, + if (failed(RewriteFunction(builder, module, clone, funcs_to_rewrite, + clone_control_flow_ops, clone_control_flow_blocks, /*is_clone=*/true))) llvm_unreachable( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 0ee5d1dee5925d..108b1bf6e6bc86 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -75,7 +75,7 @@ def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), (CHLO_BroadcastCompareOp (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE))>; @@ -158,18 +158,18 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$l_cmp $l, (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$r_cmp $r, 
(MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseIntElementsAttr)), + (NullDenseI64ArrayAttr)), (CHLO_BroadcastSubOp $div, (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), - (NullDenseIntElementsAttr)), $div), + (NullDenseI64ArrayAttr)), $div), [(SignedIntTensor $l)]>; // FloorDiv of unsigned is just div. @@ -189,19 +189,19 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), (CHLO_BroadcastCompareOp (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"NE">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$r_cmp $r, (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $r_cmp, $rem_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), - (NullDenseIntElementsAttr)), + (NullDenseI64ArrayAttr)), (CHLO_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem), [(TensorOf<[I8, I16, I32, I64, F16, F32, F64]> $l)]>; @@ -580,7 +580,7 @@ foreach Mapping = [ [TF_DigammaOp, CHLO_DigammaOp], [TF_ExpOp, MHLO_ExpOp], [TF_Expm1Op, MHLO_Expm1Op], - [TF_ErfOp, CHLO_ErfOp], + [TF_ErfOp, MHLO_ErfOp], [TF_ErfcOp, CHLO_ErfcOp], [TF_FloorOp, MHLO_FloorOp], [TF_ImagOp, MHLO_ImagOp], @@ -694,13 +694,13 @@ 
def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_BroadcastAddOp:$threshold (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), (MHLO_ConstantOp (GetScalarOfType<2> $features)), - (NullDenseIntElementsAttr) + (NullDenseI64ArrayAttr) ), (MHLO_SelectOp:$output (CHLO_BroadcastCompareOp $features, (MHLO_NegOp $threshold), - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), @@ -709,7 +709,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_BroadcastCompareOp $features, $threshold, - (NullDenseIntElementsAttr), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 122a9084771d88..2c49198be7bad8 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -40,7 +40,6 @@ limitations under the License. 
#include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/cpu/hlo_xla_runtime_pipeline.h" -#include "xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 10862b22c4f8d6..04cd4282e5c451 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -344,7 +344,6 @@ tf_python_pybind_extension( ":tfr", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/python/lib/core:pybind11_lib", - "//tensorflow/python/lib/core:pybind11_status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -352,6 +351,7 @@ tf_python_pybind_extension( "@llvm-project//mlir:Parser", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", "@pybind11", ], ) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc index 760ddab974c7fd..5d572f8278684b 100644 --- a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc +++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc @@ -15,24 +15,25 @@ limitations under the License. 
#include +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tensorflow/python/lib/core/pybind11_status.h" PYBIND11_MODULE(tfr_wrapper, m) { m.def("verify", [](std::string input) { diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index a73fc10ae083c4..315d9fca646d1d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -313,25 +313,21 @@ cc_library( ":tf_to_tfrt", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_type", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - 
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:errors", "@tf_runtime//:bef", "@tf_runtime//:core_runtime", - "@tf_runtime//:hostcontext", "@tf_runtime//:mlirtobef", - "@tf_runtime//:tensor", ], ) @@ -707,6 +703,7 @@ cc_library( cc_library( name = "backend_compiler", + srcs = ["backend_compiler.cc"], hdrs = ["backend_compiler.h"], deps = [ "//tensorflow/core/tfrt/runtime", diff --git a/third_party/xla/third_party/tsl/tsl/platform/gif.h b/tensorflow/compiler/mlir/tfrt/backend_compiler.cc similarity index 73% rename from third_party/xla/third_party/tsl/tsl/platform/gif.h rename to tensorflow/compiler/mlir/tfrt/backend_compiler.cc index 865b6f201e66fe..7c04c778fda8da 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/gif.h +++ b/tensorflow/compiler/mlir/tfrt/backend_compiler.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_TSL_PLATFORM_GIF_H_ -#define TENSORFLOW_TSL_PLATFORM_GIF_H_ +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" -#include "gif_lib.h" // from @gif +namespace tensorflow { -#endif // TENSORFLOW_TSL_PLATFORM_GIF_H_ +BackendCompiler::~BackendCompiler() = default; + +} diff --git a/tensorflow/compiler/mlir/tfrt/backend_compiler.h b/tensorflow/compiler/mlir/tfrt/backend_compiler.h index 827dc92bd72f2e..0e959f04f43554 100644 --- a/tensorflow/compiler/mlir/tfrt/backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/backend_compiler.h @@ -17,13 +17,16 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "tensorflow/core/tfrt/runtime/runtime.h" namespace tensorflow { class BackendCompiler { public: - virtual ~BackendCompiler() = default; + virtual ~BackendCompiler(); + + virtual void GetDependentDialects(mlir::DialectRegistry& registry) const {} // Compile the `module` in TF dialect. The result module should be also in TF // dialect. diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index 87ee9fc91ca57e..7fbc42ad3db93f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -443,6 +443,8 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { `tf.IfrtLoadVariableOp` converts the tensor into an IFRT array based on device and sharding configuration specified in `VariableDeviceShardingConfigProto`. + + This op returns a scalar string tensor as a key for the user to look up the loaded array.
}]; let arguments = (ins @@ -450,6 +452,10 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { StrAttr:$device_sharding_config_proto_text, StrAttr:$name ); + + let results = (outs + TFTensorType:$array_key + ); } diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index b96eae576e4983..b15cc904b8039c 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -17,24 +17,24 @@ limitations under the License. #include -#include "absl/strings/str_split.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime 
-#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime -#include "tfrt/core_runtime/op_handler.h" // from @tf_runtime -#include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/tensor/dense_host_tensor_view.h" // from @tf_runtime namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h index 94b7f73fd73068..091e6642650b25 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h @@ -23,7 +23,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir index a23fafec92c028..dbb77732a3d6f6 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir @@ -1,17 +1,17 @@ -// RUN: tf-tfrt-opt -split-input-file -rewrite-cluster-to-ifrt-call=tpu-compile-metadata-debug %s | FileCheck %s +// RUN: tf-tfrt-opt -split-input-file -rewrite-cluster-to-ifrt-call %s | FileCheck %s // TODO(b/316226111): the printer may not guarantee the same order of fields. Rewrite the checks to be less sensitive to proto serialization formats. 
// ----- // Non-SPMD: one input and one output // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> { // CHECK-NEXT: "tf.IfrtCall"(%arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = []} // CHECK-SAME: (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>) +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 1 " -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1"], tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1704 : i32}} { @@ -33,13 +33,13 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> (tensor<1x3xf32>) { // // CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) { // CHECK-NEXT: "tf.IfrtCall"(%arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = []} // CHECK-SAME: (tensor<1x3xf32>) -> () // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>) +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } 
} kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true " -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1"], tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1704 : i32}} { @@ -60,17 +60,17 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> () { // CHECK-LABEL: func.func @serving_default(%arg0: tensor<3x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-NEXT: %0 = "tf.IfrtCall"(%arg1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK-NEXT: %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK-NEXT: %2 = "tf.IfrtCall"(%1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID]] : i64, variable_arg_indices = [], variable_names = [] +// CHECK-SAME: {program_id = [[PROGRAM_ID]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK-NEXT: %3 = "tf.add"(%0, %2) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: 
tensor<3x1xf32>) -> tensor<1x1xf32> -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return @@ -97,12 +97,12 @@ func.func private @_func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> (ten // CHECK-LABEL: func.func @serving_default(%arg0: tensor<3x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-NEXT: %0 = "tf.IfrtCall"(%arg1, %arg0) -// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [], variable_names = []} +// CHECK-SAME: {program_id = [[PROGRAM_ID:.*]] : i64, variable_arg_indices = [] // CHECK-SAME: (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> // CHECK: return // -// CHECK: func.func private @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> -// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64, tpu_compile_metadata = +// CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir new file mode 100644 index 00000000000000..ba644948c6b06d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/sink_variable_as_named_array.mlir @@ -0,0 +1,44 @@ +// RUN: tf-tfrt-opt -split-input-file -sink-variable-as-named-array %s | FileCheck %s + +// ----- +// Basic test: all variables tensors are for devices and sinked as named ifrt arrays +// +// +// CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { +// CHECK-NEXT: [[HANDLE2:%.*]] 
= "tf.VarHandleOp" +// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable"([[HANDLE2]]) +// CHECK-SAME: device_sharding_config_proto_text = "sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } device_ids: 0 device_ids: 1 " +// CHECK-SAME: name = "__y" +// CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"([[KEY]], %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [0 : i32]}> +// CHECK-SAME: : (tensor, tensor<1x3xf32>) -> tensor<1x1xf32> +// CHECK-NEXT: return [[RES]] : tensor<1x1xf32> +// +module { + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<3x1xf32> + %result = "tf.IfrtCall"(%2, %arg0) <{program_id = 6515870160938153680 : i64, variable_arg_indices = []}> { __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "} : (tensor<3x1xf32>, tensor<1x3xf32>) -> (tensor<1x1xf32>) + return %result : tensor<1x1xf32> + } +} + +// ----- +// Variable tensor for host can still be used. 
+// +// CHECK-LABEL: func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { +// CHECK: "tf.VarHandleOp" +// CHECK-NEXT: [[VARIABLE:%.*]] = "tf.ReadVariableOp" +// CHECK-NEXT: [[KEY:%.*]] = "tf.IfrtLoadVariable" +// CHECK-NEXT: "tf.MatMul"(%arg0, [[VARIABLE]]) +// CHECK-NEXT: [[RES:%.*]] = "tf.IfrtCall"(%arg0, [[KEY]]) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> +// CHECK-NEXT: return [[RES]] : tensor<1x1xf32> +// +module { + func.func @serving_default(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<3x1xf32> + %3 = "tf.MatMul"(%arg0, %2) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<1x1xf32> + %result = "tf.IfrtCall"(%arg0, %2) <{program_id = 6515870160938153680 : i64, variable_arg_indices = []}> { __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } args { dtype: DT_FLOAT shape { dim { size: 3 } dim { size: 1 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true "} : (tensor<1x3xf32>, tensor<3x1xf32>) -> (tensor<1x1xf32>) + return %result : tensor<1x1xf32> + } +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir new file mode 100644 index 00000000000000..cd2ee741ba86bf --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_merging.mlir @@ -0,0 +1,21 @@ +// RUN: tf-tfrt-opt %s 
-tf-restore-merging | FileCheck %s + +// CHECK-LABEL: func @single_restore_group +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) +func.func @single_restore_group(%arg0: tensor) -> (tensor<*xf32>, tensor<*xi32>) { + %0 = "tf.Const"() {value = dense<"foo"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %1 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %2 = "tf.RestoreV2"(%arg0, %0, %1) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<*xf32> + + %3 = "tf.Const"() {value = dense<"bar"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %4 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %5 = "tf.RestoreV2"(%arg0, %3, %4) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<*xi32> + + // CHECK: %[[NAMES:.*]] = "tf.Const"() <{value = dense<["foo", "bar"]> : tensor<2x!tf_type.string>}> + // CHECK-NEXT: %[[SHAPES:.*]] = "tf.Const"() <{value = dense<""> : tensor<2x!tf_type.string>}> + // CHECK-NEXT: %[[TENSORS:.*]]:2 = "tf.RestoreV2"(%[[ARG0]], %[[NAMES]], %[[SHAPES]]) + // CHECK-SAME: -> (tensor<*xf32>, tensor<*xi32>) + + // CHECK: return %[[TENSORS]]#0, %[[TENSORS]]#1 + func.return %2, %5 : tensor<*xf32>, tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir new file mode 100644 index 00000000000000..1aafed888aa9b9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_restore_splitting.mlir @@ -0,0 +1,18 @@ +// RUN: tf-tfrt-opt %s -tf-restore-splitting | FileCheck %s + +// CHECK-LABEL: func @single_restore +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) +func.func @single_restore(%arg0: tensor) -> (tensor<*xf32>, tensor<*xi32>) { + %0 = "tf.Const"() {value = dense<["foo", "bar"]> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> + %1 = "tf.Const"() {value = dense<""> : 
tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> + %2:2 = "tf.RestoreV2"(%arg0, %0, %1) : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xi32>) + + // CHECK: %[[FOO_NAME:.*]] = "tf.Const"() <{value = dense<"foo"> : tensor<1x!tf_type.string>}> + // CHECK: %[[FOO:.*]] = "tf.RestoreV2"(%[[ARG0]], %[[FOO_NAME]], {{.*}}) + + // CHECK: %[[BAR_NAME:.*]] = "tf.Const"() <{value = dense<"bar"> : tensor<1x!tf_type.string>}> + // CHECK: %[[BAR:.*]] = "tf.RestoreV2"(%[[ARG0]], %[[BAR_NAME]], {{.*}}) + + // CHECK: return %[[FOO]], %[[BAR]] + func.return %2#0, %2#1 : tensor<*xf32>, tensor<*xi32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index 94a28091c7235f..eb2e0587364d6e 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -458,3 +458,21 @@ func.func @xla_func(%arg0: tensor<1x3xf32>) -> tensor<*xf32> attributes {tf.entr %2 = "tf.XlaLaunch"(%arg0, %1) {__op_key = 3: i32, _noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_0, operandSegmentSizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<*xf32> func.return %2 : tensor<*xf32> } + +// ----- + +// Test lowering of IfrtLoadVariableOp + +// CHECK-LABEL: func @ifrt_load_variable_test +func.func @ifrt_load_variable_test() -> () { + // CHECK: [[HANDLE:%.*]] = tf_mlrt.executeop() + // CHECK-SAME: VarHandleOp + %0 = "tf.VarHandleOp"() {__op_key = 1: i32, device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + // CHECK-NEXT: "tf_mlrt.ifrt_load_variable"([[HANDLE]]) + // CHECK-SAME: device_sharding_config_proto_text + // CHECK-SAME: name = "__variable" + %1 = "tf.IfrtLoadVariable"(%0) <{device_sharding_config_proto_text = "sharding { } device_ids: 0 device_ids: 1 ", name = "__variable"}> {__op_key = 2: i32, device = 
"/device:CPU:0"} : (tensor>>) -> (tensor) + // CHECK-NEXT: return + func.return +} + diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir index a2e2922cc7ee7f..a872b96a2fd6b4 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/decompose_resource_op.mlir @@ -8,8 +8,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: [[const_th:%.*]] = corert.const_dense_tensor // CHECK-NEXT: [[const:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[const_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0"} // CHECK-NEXT: [[out_chain:%.*]], [[value:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(0) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"({{.*}}) -// CHECK-NEXT: [[res:%.*]] = tfrt_fallback_async.executeop key(1) cost({{.*}}) device("/device:CPU:0") "tf.GatherV2"([[value]], {{.*}}, [[const]]) -// CHECK-NEXT: [[res_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[res:%.*]] {device = "/device:CPU:0"} +// CHECK-NEXT: [[res:%.*]] = tfrt_fallback_async.executeop key(1) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.GatherV2"([[value]], {{.*}}, [[const]]) +// CHECK-NEXT: [[res_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[res]] {device = "/job:localhost/replica:0/task:0/device:CPU:0"} // CHECK-NEXT: tfrt.return [[out_chain]], [[res_th]] : !tfrt.chain, !corert.tensorhandle func.func @gather(%indices: tensor, %resource: tensor<*x!tf_type.resource>) -> tensor<*xi32> { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index c3a2a757f3776b..d120d1fe4f5005 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ 
b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -15,6 +15,8 @@ package_group( "//tensorflow/core/tfrt/saved_model/tests/...", ] + if_google([ "//learning/brain/tfrt/cpp_tests/...", + "//learning/pathways/serving/runtime/...", + "//learning/pathways/serving/tests/...", # Allow visibility from the mlir language server. "//learning/brain/mlir/mlir_lsp_server/...", ]), @@ -51,7 +53,10 @@ cc_library( name = "tf_ifrt_passes", srcs = [ "rewrite_cluster_to_ifrt_call.cc", + "sink_variable_as_named_array.cc", "tf_ifrt_passes.cc", + "tf_restore_merging.cc", + "tf_restore_splitting.cc", ], hdrs = [ "tf_ifrt_passes.h", @@ -66,22 +71,31 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:tpu_metadata_utils", + "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", "//tensorflow/core:framework", + "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:random", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/service:computation_placer_hdr", ], ) @@ -104,6 +118,7 @@ cc_library( 
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/core/tpu/kernels/xla:host_compute_ops", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -115,7 +130,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", @@ -141,7 +155,6 @@ tf_cc_test( ":tf2hlo", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 412e86ef4e39f3..b0815b8f5c0272 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -92,6 +92,7 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, model_name, entry_function_name.str(), *std::move(submodule), ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPoolDevice(), + &ifrt_model_context.GetLoadedVariableRegistry(), ifrt_model_context.GetShapeRepresentationFn()); // Register the Ifrt program to `ServingExecutableRegistry` so that @@ -159,7 +160,7 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( // Use bridge for cluster formation. 
TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tensorflow::tf2xla::v2::DeviceType::XLA_TPU_JIT, + module, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/false)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 7ade31aa2a4506..1b599fd2c33c2b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include -#include + +// Enable definition of Eigen::ThreadPoolDevice instead of just declaration. +#define EIGEN_USE_THREADS #include #include "absl/strings/str_cat.h" @@ -79,8 +81,6 @@ TEST(IfrtBackendCompilerTest, Basic) { xla::ifrt::test_util::GetClient()); Eigen::ThreadPoolDevice thread_pool_device = GetThreadPoolDevice(); - IfrtModelContext model_context(client, &thread_pool_device); - std::unique_ptr runtime = tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1); tensorflow::tfrt_stub::GraphExecutionOptions graph_execution_options( @@ -90,7 +90,7 @@ TEST(IfrtBackendCompilerTest, Basic) { &graph_execution_options, /*export_dir=*/"", &resource_context); runtime_context.resource_context().CreateResource( - "IfrtModelContext", std::move(model_context)); + "IfrtModelContext", client, &thread_pool_device); IfrtBackendCompiler compiler; TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get())); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h index 4cd5bf2cbfc3bb..3e4971826c67b6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h @@ -21,14 +21,19 @@ limitations under the License. 
namespace tensorflow { namespace ifrt_serving { -// Attribute name of a serialized TpuCompileMetadataProto. This is backward -// compatible. -inline constexpr absl::string_view kMetadataAttrName = "tpu_compile_metadata"; // Attribute name of a text TpuCompileMetadataProto. Note that the text proto is -// not backward compatible and only used for debug. +// not backward compatible and shall not be serialized. inline constexpr absl::string_view kMetadataTextAttrName = "__tpu_compile_metadata_text"; +// Name of a variable as loaded IFRT array . +inline constexpr absl::string_view kVariableArrayNameAttr = + "__variable_array_name"; + +// Attribute of a text `VariableDeviceShardingConfigProto`. +inline constexpr absl::string_view kVariableShardingConfigTextAttr = + "__variable_sharding_config_text"; + } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td index a79e91f0422983..c725aa85b0157b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td @@ -26,11 +26,52 @@ def RewriteClusterToIfrtCallPass: Pass<"rewrite-cluster-to-ifrt-call", "mlir::Mo let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"]; let constructor = "CreateRewriteClusterToIfrtCallPass()"; + } - let options = [ - Option<"tpu_compile_metadata_debug_", "tpu-compile-metadata-debug", "bool", "false", - "if enabled, output compile metadata as readable string in " - "an extra __tpu_compile_metadata_debug attribute for debug">, - ]; +def SinkVariableAsNamedArrayPass: Pass<"sink-variable-as-named-array", "mlir::ModuleOp"> { + let summary = "Sink variable tensor for tpu device as named IFRT array for tf.IfrtCall"; + let description = [{ + This pass sinks variable tensor argument to `tf.IfrtCall` as variable_arg_indices + and variable_names attributes and also lowers `tf.ReadVariableOp` to + 
`tf.IfrtLoadVariableOp`. + + The runtime ensures that `tf.IfrtCall` kernel can bind the IFRT array by + its name as input to the TPU program. + + }]; + + let constructor = "CreateSinkVariableAsNamedArrayPass()"; } + +def TfRestoreSplittingPass + : Pass<"tf-restore-splitting", "mlir::func::FuncOp"> { + let summary = "Splits `tf.RestoreV2` ops"; + + let description = [{ + This pass splits each `tf.RestoreV2` op so that one restore op handles one + variable only. This pass can split restore ops only if the tensor names and + the shape/slices arguments are constants, which is usually the case. + + Splitting monolithic restore ops into per-tensor restore ops makes it easier + to shard SavedModel initialization across multiple clusters. + }]; + + let constructor = "CreateTfRestoreSplittingPass()"; +} + +def TfRestoreMergingPass : Pass<"tf-restore-merging", "mlir::func::FuncOp"> { + let summary = "Merges `tf.RestoreV2` ops"; + + let description = [{ + This pass merges multiple `tf.RestoreV2` ops into one `tf.RestoreV2` op + using variadic results. The current implementation merges restore ops only + if they have the same `prefix` and have constant tensor names and + shape/slice arguments. + + This pass is run in order to undo `tf-restore-splitting` after cluster + formation and reduce the op dispatch overhead. 
+ }]; + + let constructor = "CreateTfRestoreMergingPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc index 89428db2dd148c..b8bd4685919d4c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc @@ -190,9 +190,15 @@ class RewriteClusterToIfrtCallPass return signalPassFailure(); } + auto metadata_attr = + ifrt_program->getAttrOfType(kMetadataTextAttrName); + if (!metadata_attr) { + return signalPassFailure(); + } + ifrt_call_op->setAttr(kMetadataTextAttrName, metadata_attr); + // TODO(b/304839793): populate variable names after adding a variable // hoisting pass. - ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({})); ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({})); ifrt_call_op.setProgramId(program_id); @@ -214,25 +220,24 @@ class RewriteClusterToIfrtCallPass if (mlir::failed(GetTpuCompileMetadata(cluster_func, devices, &metadata))) { return signalPassFailure(); } + std::string serialized_metadata; + tsl::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.PrintToString(metadata, &serialized_metadata); - cloned_ifrt_program->setAttr( - kMetadataAttrName, builder.getStringAttr(metadata.SerializeAsString())); + cloned_ifrt_program->setAttr(kMetadataTextAttrName, + builder.getStringAttr(serialized_metadata)); - if (tpu_compile_metadata_debug_) { - std::string serialized_metadata; - tsl::protobuf::TextFormat::Printer printer; - printer.SetSingleLineMode(true); - printer.PrintToString(metadata, &serialized_metadata); - - cloned_ifrt_program->setAttr(kMetadataTextAttrName, - builder.getStringAttr(serialized_metadata)); - } cloned_ifrt_program.setName(ifrt_program_name); int64_t program_id = NewProgramId(); 
cloned_ifrt_program->setAttr("tfrt_ifrt_serving.program_id", builder.getI64IntegerAttr(program_id)); + // Make clonet ifrt program public so that it does not get dropped by + // inliner. + cloned_ifrt_program.setPublic(); + builder.setInsertionPoint(cluster_func); mlir::TF::IfrtCallOp ifrt_call_op = builder.create( @@ -241,9 +246,12 @@ class RewriteClusterToIfrtCallPass // TODO(b/304839793): populate variable names after adding a variable // hoisting pass. - ifrt_call_op.setVariableNamesAttr(builder.getArrayAttr({})); ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({})); ifrt_call_op.setProgramId(program_id); + // Additionally attach tpu_compile_metadata to IfrtCallOp. Some subsequent + // pass such as SinkVariableAsNamedArrayPass relies on this attribute. + ifrt_call_op->setAttr(kMetadataTextAttrName, + builder.getStringAttr(serialized_metadata)); cluster_func->replaceAllUsesWith(ifrt_call_op.getResults()); cluster_func->erase(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc new file mode 100644 index 00000000000000..450c441ef4e3f7 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/sink_variable_as_named_array.cc @@ -0,0 +1,359 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h" +#include "xla/service/computation_placer.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_SINKVARIABLEASNAMEDARRAYPASS +#define GEN_PASS_DECL_SINKVARIABLEASNAMEDARRAYPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class SinkVariableAsNamedArrayPass + : public 
impl::SinkVariableAsNamedArrayPassBase< + SinkVariableAsNamedArrayPass> { + public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::OpBuilder builder(&getContext()); + + absl::flat_hash_map variable_config_by_name; + llvm::SmallDenseMap + ifrt_call_argument_configs; + + // First, we backtrack from IFRT call to collect variable tensors that needs + // to converted to loaded ifrt arrays and their associated information such + // as their name and defining ops. + std::vector ifrt_call_ops; + module.walk([&ifrt_call_ops](mlir::TF::IfrtCallOp call) { + ifrt_call_ops.push_back(call); + }); + for (const auto& call : ifrt_call_ops) { + if (mlir::failed(CollectVariablesUsedByDevice( + call, variable_config_by_name, ifrt_call_argument_configs))) { + return signalPassFailure(); + } + } + + // Rewrite ReadVariableOp with IfrtLoadVariableOp + llvm::SmallDenseMap + read_to_load; + for (auto& [name, variable_config] : variable_config_by_name) { + for (auto& read_variable_op : variable_config.read_variable_op) { + builder.setInsertionPointAfter(read_variable_op); + // TODO(b/319045348): consider use resource alias analysis for this. + auto var_handle = GetDefiningOp( + read_variable_op.getResource()); + + if (!var_handle) { + read_variable_op->emitError( + "ReadVariableOp has no defining VarHandleOp."); + return signalPassFailure(); + } + + auto load_variable_op = builder.create( + read_variable_op->getLoc(), + mlir::RankedTensorType::get( + {}, builder.getType()), + var_handle.getResult(), + builder.getStringAttr(variable_config.device_sharding_config), + builder.getStringAttr(name)); + read_to_load[read_variable_op] = load_variable_op; + } + } + + // Rewrite ifrt call: variable tensors are sunk as attribute. + // The runtime guarantees the binding of corresponding loaded ifrt array + // based on attributes. 
+ for (auto& call : ifrt_call_ops) { + if (!call.getVariableArgIndicesAttr().empty()) { + call->emitError() << "Expect empty " + << call.getVariableArgIndicesAttrName().str() + << " attributes, but got " + << call.getVariableArgIndicesAttr().size() + << " elements"; + return signalPassFailure(); + } + if (call->getOpOperands().size() != + ifrt_call_argument_configs[call].size()) { + call->emitError() << "IfrtCallOp got " << call->getOpOperands().size() + << " operands, but expects " + << ifrt_call_argument_configs[call].size(); + return signalPassFailure(); + } + llvm::SmallVector variable_arg_indices; + llvm::SmallVector variable_arg_names; + llvm::SmallVector updated_args; + + for (const auto& [arg_idx, arg] : + llvm::enumerate(ifrt_call_argument_configs[call])) { + if (arg.is_variable) { + variable_arg_names.push_back( + builder.getStringAttr(arg.variable_name)); + variable_arg_indices.push_back(arg_idx); + // Variable use the key from IfrtLoadVariable. + updated_args.push_back( + read_to_load[arg.read_variable_op].getResult()); + } else { + // non variable + updated_args.push_back(call->getOperand(arg_idx)); + } + } + + builder.setInsertionPointAfter(call); + auto updated_ifrt_call = builder.create( + call->getLoc(), call.getResultTypes(), updated_args); + + updated_ifrt_call->setAttrs(call->getAttrs()); + // Update variable_arg_indices attribute. + updated_ifrt_call.setVariableArgIndicesAttr( + builder.getI32ArrayAttr(variable_arg_indices)); + + call.replaceAllUsesWith(updated_ifrt_call); + call.erase(); + } + + // Delete all ReadVariableOps that are not used. + for (auto& [name, variable_config] : variable_config_by_name) { + for (auto& read_variable_op : variable_config.read_variable_op) { + if (read_variable_op.use_empty()) { + read_variable_op.erase(); + } + } + } + } + + private: + struct VariableConfig { + // VariableDeviceShardingConfig text proto. + std::string device_sharding_config; + // All ReadVariableOps that returns this named variable. 
+ std::vector read_variable_op; + }; + struct IfrtArgConfig { + bool is_variable; + std::string variable_name; + mlir::TF::ReadVariableOp read_variable_op; + }; + using IfrtArgConfigList = llvm::SmallVector; + + // Find defining ReadVariableOps and also build argument configuration map of + // a IfrtCallOp. + mlir::LogicalResult CollectVariablesUsedByDevice( + mlir::TF::IfrtCallOp call, + absl::flat_hash_map& variable_config_by_name, + llvm::SmallDenseMap& + ifrt_call_argument_configs) { + IfrtArgConfigList& args = ifrt_call_argument_configs[call]; + + tensorflow::tpu::TPUCompileMetadataProto metadata; + + // TODO(b/319045348): remove the usage kMetadataAttrName. + auto metadata_attr = + call->getAttrOfType(kMetadataTextAttrName); + if (metadata_attr && !metadata_attr.empty()) { + if (!tensorflow::protobuf::TextFormat::ParseFromString( + metadata_attr.getValue().str(), &metadata)) { + return call.emitError() + << "Failed to parse TPUCompileMetadataProto from attr :" + << metadata_attr.getValue().str(); + } + } else { + return call.emitError() + << "Failed to Get TPUCompileMetadataProto from attr"; + } + + for (const auto& [arg_idx, input] : llvm::enumerate(call->getOperands())) { + // Assuming the nested function calls are inlined. 
+ if (auto read_variable_op = + GetDefiningOp(input)) { + mlir::FailureOr variable_tensor_name = + GetVariableTensorName(read_variable_op); + + if (mlir::failed(variable_tensor_name)) { + return mlir::failure(); + } + + absl::StatusOr device_sharding_config = + GetVariableShardingConfig(metadata, arg_idx); + if (!device_sharding_config.ok()) { + return call->emitError() + << "Fail to get device sharding config for argument index " + << arg_idx; + } + VariableConfig& variable_config = + variable_config_by_name[*variable_tensor_name]; + if (!variable_config.read_variable_op.empty()) { + if (variable_config.device_sharding_config != + *device_sharding_config) { + return call->emitError() + << "A variable tensor has different sharding config: " + << variable_config.device_sharding_config << " vs " + << *device_sharding_config; + } + } else { + variable_config.device_sharding_config = *device_sharding_config; + } + + variable_config.read_variable_op.push_back(read_variable_op); + args.push_back({.is_variable = true, + .variable_name = *variable_tensor_name, + .read_variable_op = read_variable_op}); + } else { + args.push_back({.is_variable = false}); + } + } + + return mlir::success(); + } + + // The returned variable tensor name is used both as an internal hash key, + // and as the binding name between the tensor and the array in the + // runtime. 
+ std::string GetVariableTensorName(mlir::TF::VarHandleOp var_handle) { + return absl::StrCat(absl::string_view(var_handle.getContainer()), "__", + absl::string_view(var_handle.getSharedName())); + } + + mlir::FailureOr GetVariableTensorName( + mlir::TF::ReadVariableOp read_variable_op) { + mlir::Value variable_definition = read_variable_op.getResource(); + auto var_handle = GetDefiningOp(variable_definition); + + if (!var_handle) { + return read_variable_op->emitError("ReadVariableOp has no defining op."); + } + + return GetVariableTensorName(var_handle); + } + + absl::StatusOr GetVariableShardingConfig( + const tensorflow::tpu::TPUCompileMetadataProto& metadata, int arg_idx) { + tensorflow::ifrt_serving::VariableDeviceShardingConfigProto + device_sharding_config; + std::vector device_ids; + + if (metadata.has_device_assignment()) { + absl::StatusOr> da = + xla::DeviceAssignment::Deserialize(metadata.device_assignment()); + + if (!da.ok()) { + return da.status(); + } + if (metadata.num_replicas() != (*da)->replica_count() || + metadata.num_cores_per_replica() != (*da)->computation_count()) { + return absl::FailedPreconditionError(absl::StrCat( + "Device assignment has different replica count: ", + metadata.num_replicas(), " vs ", (*da)->replica_count(), + " or computation count: ", metadata.num_cores_per_replica(), " vs ", + (*da)->computation_count(), ".")); + } + + device_ids.reserve(metadata.num_replicas() * + metadata.num_cores_per_replica()); + for (int i = 0; i < (*da)->replica_count(); ++i) { + for (int j = 0; j < (*da)->computation_count(); ++j) { + device_ids.push_back((**da)(i, j)); + } + } + } else { + // Default use first N devices. 
+ device_ids.resize(metadata.num_replicas() * + metadata.num_cores_per_replica()); + std::iota(device_ids.begin(), device_ids.end(), 0); + } + + device_sharding_config.mutable_device_ids()->Assign(device_ids.begin(), + device_ids.end()); + + if (metadata.args_size() > 0) { + *device_sharding_config.mutable_sharding() = + metadata.args(arg_idx).sharding(); + } + + std::string proto_text; + tsl::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.PrintToString(device_sharding_config, &proto_text); + + return proto_text; + } + + template + OpT GetDefiningOp(const mlir::Value& value) { + mlir::Operation* op = value.getDefiningOp(); + + while (op && !llvm::isa(op)) { + if (llvm::isa(op)) { + op = op->getOperand(0).getDefiningOp(); + } else { + return nullptr; + } + } + + if (op != nullptr) { + return llvm::dyn_cast(op); + } else { + return nullptr; + } + } +}; + +} // namespace + +std::unique_ptr> +CreateSinkVariableAsNamedArrayPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir new file mode 100644 index 00000000000000..4eff0866ba7a66 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/xla_call_host_callback.mlir @@ -0,0 +1,23 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1758 : i32}} { + + func.func private @callee(%arg0: tensor, %arg1: tensor<*xi32>) { + "tf.XlaHostCompute"(%arg0, %arg1) <{ancestors = [], key = "@test_callee", recv_key = "", send_key = "", shapes = []}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor, tensor<*xi32>) -> () + return + } + + // The mlir module in XlaCallModule is serialized from: + // + // func.func private @_stablehlo_main_0(%arg0: tensor, %arg1: tensor<*xi32>) 
-> () attributes {_from_xla_call_module} { + // stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @callee}} : (tensor, tensor<*xi32>) -> () + // return + // } + // + // func.func @main(%arg0: tensor, %arg1: tensor<*xi32>) -> () { + // "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape], dim_args_spec = [], _entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = { mhlo.num_partitions = 1 }, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<*xi32>) -> () + // func.return + // } + func.func @main(%arg0: tensor, %arg1: tensor<*xi32>) attributes {tfrt_ifrt_serving.program_id = -2372940092539171444 : i64, __tpu_compile_metadata_text = "args { dtype: DT_INT32 kind: PARAMETER sharding { } } args { dtype: DT_INT32 kind: PARAMETER sharding { } } num_replicas: 1 num_cores_per_replica: 1 use_spmd_for_xla_partitioning: true compile_options { }"} { + "tf.XlaCallModule"(%arg0, %arg1) <{Sout = [], dim_args_spec = [], function_list = [@callee], module = "ML\EFR\0DStableHLO_v0.17.6\00\01\19\05\01\05\09\01\03\0B\03\07\0F\13\17\03M-\0D\01\19\0B\13\0B\0F\13\13\13\13\13\0B\13\13\03\15\0B\0B\0B\0B\13\0B\0F\0B\0B\0B\01\03\0F\03\0B3\07\0B\17\07\02\B1\05\0D\03\03\05\07\05\0F\11\01\05\17\01A\0B\17\01!\07\17\01!Q\17\01!}\03\03\13!\05\11\17\01#\0B\17\01%\0B\03\01\1D\13#\09\1D\15\0D\03#%\1D\17\13\0B\01\0B\05\1D\19\05\03\01\02\04)\03\00\FF\FF\FF\FF\FF\FF\FF\FF\05\1B3\05\11\05\03\07\01\1D\04O\05\01Q\09\03\01\07\04=\03\01\05\03P\0B\03\07\04)\03\05\0B\05\07\0D\0F\0F\00\05E\15\11\05\05\01\03\07\00\17\06\03\01\05\01\00j\03\1B)\1B\0B\03%)\95\15\1F\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00custom_call_v1\00return_v1\00experimental/users/deqiangc/mira/testdata/xla_call_module_serialized.mlir\00mhlo.num_partitions\00tf.backend_config\00\00main\00called_index\00tf.call_tf_function\00\08'\07\05\01\01\0B\19\1D\19\1F\1B\11'\1B)\19+\19\19\19", platforms = [], version = 5 : i64}> : 
(tensor, tensor<*xi32>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 96794df1b7e34f..3cecc9c90ab4e2 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -59,7 +59,6 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -68,37 +67,27 @@ namespace { static constexpr absl::string_view kEntryFuncName = "main"; absl::StatusOr GetCompileMetadata( - mlir::func::FuncOp op, absl::Span inputs, + mlir::func::FuncOp op, absl::Span inputs, const xla::ifrt::Client& ifrt_client) { tensorflow::tpu::TPUCompileMetadataProto metadata; - auto metadata_attr = op->getAttrOfType(kMetadataAttrName); auto metadata_text_attr = op->getAttrOfType(kMetadataTextAttrName); - if (metadata_attr && !metadata_attr.getValue().empty()) { - // tpu_compile_metadata takes priority if exists. - VLOG(1) << "Parsing from attribute " << kMetadataAttrName << " : " - << metadata_attr.getValue().str(); - if (!metadata.ParseFromString(metadata_attr.getValue())) { - return absl::InternalError( - absl::StrCat("Failed to parse tpu_compile_metadata attribute:", - metadata_attr.getValue().str())); - } - } else if (metadata_text_attr && !metadata_text_attr.getValue().empty()) { + if (metadata_text_attr && !metadata_text_attr.getValue().empty()) { // Try __tpu_compile_metadata_text attribute. This only for debugging // purpose. 
VLOG(1) << "Parsing from attribute " << kMetadataTextAttrName << metadata_text_attr.getValue().str(); if (!tsl::protobuf::TextFormat::ParseFromString( - metadata_text_attr.getValue(), &metadata)) { + metadata_text_attr.getValue().str(), &metadata)) { return absl::InvalidArgumentError(absl::StrCat( "Attribute ", kMetadataTextAttrName, ":", metadata_text_attr.getValue().str(), " cannot be parsed")); } } else { - return absl::InvalidArgumentError(absl::StrCat( - "Missing ", kMetadataAttrName, " and ", kMetadataTextAttrName)); + return absl::InvalidArgumentError( + absl::StrCat("Missing ", kMetadataTextAttrName)); } VLOG(3) << "TpuCompileMetadata before shape is populated " << metadata; @@ -126,14 +115,14 @@ absl::StatusOr GetCompileMetadata( "Only support PARAMETER, but got ", metadata.args(i).kind())); } - if (metadata.args(i).dtype() != inputs[i].dtype()) { + if (metadata.args(i).dtype() != inputs[i].dtype) { return absl::InternalError(absl::StrCat("Dtype mismatched! Expected ", metadata.args(i).dtype(), " got ", - inputs[i].dtype())); + inputs[i].dtype)); } // Update shape. - *metadata.mutable_args(i)->mutable_shape() = inputs[i].shape().AsProto(); + *metadata.mutable_args(i)->mutable_shape() = inputs[i].shape.AsProto(); } // Create a default device assignment if one is not given by the model. 
@@ -154,7 +143,7 @@ absl::StatusOr GetCompileMetadata( } // namespace absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, + mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { if (VLOG_IS_ON(1)) { @@ -192,7 +181,7 @@ absl::StatusOr CompileTfToHlo( std::vector arg_shapes; for (const auto& input : inputs) { - arg_shapes.push_back(input.shape()); + arg_shapes.push_back(input.shape); } bool use_tuple_args = false; @@ -222,6 +211,7 @@ absl::StatusOr CompileTfToHlo( Tf2HloResult result; result.mlir_hlo_module = xla::llvm_ir::CreateMlirModuleOp(module->getLoc()); result.compile_metadata = std::move(compile_metadata); + result.host_compute_metadata = compilation_result.host_compute_metadata; TF_RETURN_IF_ERROR(xla::ConvertHloToMlirHlo( *result.mlir_hlo_module, &compilation_result.computation->proto())); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index f170a5fc5d9265..74fa271401f547 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -24,20 +24,28 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" namespace tensorflow { namespace ifrt_serving { +struct DtypeAndShape { + tensorflow::DataType dtype; + tensorflow::TensorShape shape; +}; + struct Tf2HloResult { mlir::OwningOpRef mlir_hlo_module; tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + tf2xla::HostComputeMetadata host_compute_metadata; }; // A class that convert tf module to hlo // TODO(b/304839793): provide wrap persistent compilation cache. absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, + mlir::ModuleOp module, absl::Span inputs, absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 89f353c68eafdd..7ee1c450426b20 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -35,8 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tsl/lib/core/status_test_util.h" @@ -124,13 +122,12 @@ TEST(Tf2HloTest, Tuple) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({1, 3})); - tensorflow::Tensor y(DT_FLOAT, tensorflow::TensorShape({3, 1})); - tensors.push_back(x); - tensors.push_back(y); - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 3}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {3, 1}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); TF_ASSERT_OK(result.status()); } @@ -158,12 +155,11 @@ TEST(Tf2HloTest, Spmd) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({4, 64})); - tensors.push_back(x); - - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -227,16 +223,13 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, xla::ifrt::test_util::GetClient()); - std::vector 
tensors; - tensorflow::Tensor x(DT_FLOAT, tensorflow::TensorShape({4, 64})); - tensorflow::Tensor y(DT_FLOAT, tensorflow::TensorShape({64, 10})); - tensorflow::Tensor z(DT_FLOAT, tensorflow::TensorShape({1, 4})); - tensors.push_back(x); - tensors.push_back(y); - tensors.push_back(z); - - auto result = CompileTfToHlo(mlir_module.get(), tensors, "main", *client, - tensorflow::IdentityShapeRepresentationFn()); + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {4, 64}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {64, 10}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {1, 4}}); + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -302,6 +295,47 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { EXPECT_THAT(result->compile_metadata, EqualsProto(expected_compile_metadata)); } +// Multiple input and multiple out. 
+TEST(Tf2HloTest, XlaCallHostCallback) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/xla_call_host_callback.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, + mlir::ParserConfig(&context)); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + + auto result = + CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", *client, + tensorflow::IdentityShapeRepresentationFn()); + + TF_ASSERT_OK(result.status()); + + ASSERT_EQ((*result).host_compute_metadata.device_to_host().size(), 1); + ASSERT_EQ( + (*result).host_compute_metadata.device_to_host().begin()->metadata_size(), + 2); + ASSERT_EQ((*result).host_compute_metadata.host_to_device().size(), 0); +} + } // namespace } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 47fe04925d1881..53bd55cc0d2799 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -27,10 +27,12 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/core/util/debug_data_dumper.h" namespace tensorflow { @@ -69,6 +71,20 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); pm.addPass(CreateRewriteClusterToIfrtCallPass()); + + // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass + // rely on the co-existence of VarHandle and ReadVariable in the same + // function. + // First, we inline all the function calls. This will sink VarHandle + // with ReadVariable in most cases. Then SinkInvariantOpsPass will sink + // VarHandle to a few special Ops that inliner does not handle. + // TODO(b/319045348): the bridge before this pipeline already does some + // inlining. Consider removing this inliner. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(::tensorflow::CreateSinkInInvariantOpsPass()); + + // Sink variable tensor as named array in IFRT. 
+ pm.addPass(CreateSinkVariableAsNamedArrayPass()); } } // namespace @@ -111,9 +127,7 @@ absl::Status RunClusterToIfrtRuntimeOpsPassPipeline( } // Register all IfrtPass -void RegisterTfIfrtPasses() { - mlir::registerPass([]() { return CreateRewriteClusterToIfrtCallPass(); }); -} +void RegisterTfIfrtPasses() { registerTfrtIfrtServingPasses(); } } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h index b04c61eb5c76d8..084b170fa5b9a4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -30,6 +31,19 @@ namespace ifrt_serving { std::unique_ptr> CreateRewriteClusterToIfrtCallPass(); +// Creates a pass that sinks variable tensor argument to `tf.IfrtCall` as named +// arrays and lowers `tf.ReadVariableOp` to `tf.IfrtLoadVariableOp`. +std::unique_ptr> +CreateSinkVariableAsNamedArrayPass(); + +// Creates a pass that splits `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreSplittingPass(); + +// Creates a pass that merges `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreMergingPass(); + #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc new file mode 100644 index 00000000000000..5220824d3f716a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_merging.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFRESTOREMERGINGPASS +#define GEN_PASS_DECL_TFRESTOREMERGINGPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfRestoreMergingPass + : public impl::TfRestoreMergingPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = 
getOperation(); + for (mlir::Block& block : func) { + // Group `tf.RestoreV2` ops by prefixes and merge each group. + llvm::SmallDenseMap> + restore_groups; + for (auto restore : block.getOps()) { + restore_groups[restore.getPrefix()].push_back(restore); + } + for (const auto& restores : llvm::make_second_range(restore_groups)) { + if (mlir::failed(MergeRestores(restores))) { + return signalPassFailure(); + } + } + } + } + + private: + mlir::DenseStringElementsAttr GetStringTensorAttr( + llvm::ArrayRef values) { + const int size = values.size(); + const auto type = mlir::RankedTensorType::get( + {size}, mlir::TF::StringType::get(&getContext())); + return mlir::DenseStringElementsAttr::get(type, values); + } + + // Merges `tf.RestoreV2` ops with the same prefix. Ignores restore ops with + // non-constant `tensor_names` and/or `shape_and_slices`. + mlir::LogicalResult MergeRestores( + llvm::ArrayRef restores) { + if (restores.size() <= 1) { + return mlir::success(); + } + + // All restore ops must have the same prefix. 
+ const mlir::Value prefix = + mlir::TF::RestoreV2Op(restores.front()).getPrefix(); + + std::vector restores_to_merge; + std::vector values_to_replace; + std::vector merged_tensor_names; + std::vector merged_shape_and_slices; + + std::vector restore_locs; + std::vector tensor_names_locs; + std::vector shape_and_slices_locs; + + for (mlir::TF::RestoreV2Op restore : restores) { + mlir::DenseStringElementsAttr tensor_names; + mlir::DenseStringElementsAttr shape_and_slices; + if (!mlir::matchPattern(restore, + mlir::m_Op( + mlir::matchers::m_Val(prefix), + mlir::m_Constant(&tensor_names), + mlir::m_Constant(&shape_and_slices)))) { + continue; + } + if (tensor_names.size() != restore.getNumResults() || + shape_and_slices.size() != restore.getNumResults()) { + return restore.emitOpError() + << "returns an inconsistent number of results"; + } + + restores_to_merge.push_back(restore); + llvm::append_range(values_to_replace, restore.getTensors()); + llvm::append_range(merged_tensor_names, + tensor_names.getValues()); + llvm::append_range(merged_shape_and_slices, + shape_and_slices.getValues()); + + restore_locs.push_back(restore.getLoc()); + tensor_names_locs.push_back(restore.getTensorNames().getLoc()); + shape_and_slices_locs.push_back(restore.getShapeAndSlices().getLoc()); + } + if (restores_to_merge.size() <= 1) { + return mlir::success(); + } + + // Insert the merged restore op right before the first restore op to be + // merged in order to keep the dominance property. 
+ mlir::OpBuilder builder(restores_to_merge.front()); + + auto new_tensor_names = builder.create( + builder.getFusedLoc(tensor_names_locs), + GetStringTensorAttr(merged_tensor_names)); + auto new_shape_and_slices = builder.create( + builder.getFusedLoc(shape_and_slices_locs), + GetStringTensorAttr(merged_shape_and_slices)); + + auto new_restore = builder.create( + builder.getFusedLoc(restore_locs), + mlir::TypeRange(mlir::ValueRange(values_to_replace)), prefix, + new_tensor_names, new_shape_and_slices); + for (auto [old_value, new_value] : + llvm::zip(values_to_replace, new_restore.getTensors())) { + old_value.replaceAllUsesWith(new_value); + } + + for (mlir::TF::RestoreV2Op restore : restores_to_merge) { + restore.erase(); + } + return mlir::success(); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfRestoreMergingPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc new file mode 100644 index 00000000000000..130ca0a2e90b74 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_restore_splitting.cc @@ -0,0 +1,122 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFRESTORESPLITTINGPASS +#define GEN_PASS_DECL_TFRESTORESPLITTINGPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfRestoreSplittingPass + : public impl::TfRestoreSplittingPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + const mlir::WalkResult result = + func.walk([&](mlir::TF::RestoreV2Op restore) { + if (mlir::failed(SplitRestore(restore))) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return signalPassFailure(); + } + } + + private: + mlir::DenseStringElementsAttr GetStringTensorAttr( + llvm::ArrayRef values) { + const int size = values.size(); + const auto type = mlir::RankedTensorType::get( + {size}, mlir::TF::StringType::get(&getContext())); + return mlir::DenseStringElementsAttr::get(type, values); + 
} + + // Splits the `tf.RestoreV2` op into per-variable restore ops if its + // `tensor_name` and `shape_and_slices` are constant. + mlir::LogicalResult SplitRestore(mlir::TF::RestoreV2Op restore) { + mlir::DenseStringElementsAttr tensor_names; + mlir::DenseStringElementsAttr shape_and_slices; + if (!mlir::matchPattern(restore, + mlir::m_Op( + /*prefix=*/mlir::matchers::m_Any(), + mlir::m_Constant(&tensor_names), + mlir::m_Constant(&shape_and_slices)))) { + return mlir::success(); + } + if (tensor_names.size() != restore.getNumResults() || + shape_and_slices.size() != restore.getNumResults()) { + return restore.emitOpError() + << "returns an inconsistent number of results"; + } + + mlir::OpBuilder builder(restore); + for (auto [tensor_name, shape_and_slice, result] : + llvm::zip(tensor_names.getValues(), + shape_and_slices.getValues(), + restore.getTensors())) { + auto new_tensor_names = + builder.create(restore.getTensorNames().getLoc(), + GetStringTensorAttr({tensor_name})); + + auto new_shape_and_slices = builder.create( + restore.getShapeAndSlices().getLoc(), + GetStringTensorAttr({shape_and_slice})); + + auto new_restore = builder.create( + restore.getLoc(), mlir::TypeRange({result.getType()}), + restore.getPrefix(), new_tensor_names, new_shape_and_slices); + result.replaceAllUsesWith(new_restore.getTensors()[0]); + } + + restore.erase(); + return mlir::success(); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfRestoreSplittingPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD index eb61561911f2b4..bec0d45d6b1525 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -51,6 +51,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", 
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", "//tensorflow/compiler/mlir/tensorflow:tensorflow_tfrt_ops_inc_gen", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) @@ -65,6 +66,7 @@ cc_library( ":util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_tfrt_ops_inc_gen", "//tensorflow/compiler/mlir/tfrt:constants", "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", "//tensorflow/compiler/mlir/tfrt:transform_utils", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index be81e54ece871f..a9dca8adca598c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h.inc" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" @@ -322,6 +323,26 @@ class GetResourceOpConversion final } }; +// Convert tf.IfrtLoadVariableOp to tf_mlrt.IfrtLoadVariableOp +class IfrtLoadVariableOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::IfrtLoadVariableOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types( + op->getNumResults(), rewriter.getType()); + auto new_op = rewriter.create( + op.getLoc(), result_types, adaptor.getOperands()[0], + 
op.getDeviceShardingConfigProtoTextAttr(), op.getNameAttr()); + rewriter.replaceOp(op, new_op); + + return mlir::success(); + } +}; + std::optional DecodeLongName(mlir::Location loc) { if (auto name_loc = loc.dyn_cast()) { return name_loc.getName().str(); @@ -1167,8 +1188,8 @@ class TfToMlrtConversionPass // Order the list of added ops alphabetically. patterns.add(&context, &type_converter_, &symbol_table); patterns.add(&context); + SetResourceOpConversion, IfrtLoadVariableOpConversion, + TFAwaitOpConversion, TFPromiseOpConversion>(&context); patterns.add(type_converter_, &context); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc index fb110fb01f2ef1..69c8b08dcbc0b1 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -29,13 +30,15 @@ bool UseFallback(mlir::Operation *op) { // TODO(b/173017701): have a centralized place to hold the information // whether a TF op should be lowered to FallbackExecute op. + // TODO(b/319045348): Define trait to reflect that IfrtLoadVariableOp has no + // TF kernels so that we don't need to check every op here. 
// LINT.IfChange(fallback_allow_list) - return !llvm::isa(op); + return !llvm::isa< + mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp, + mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtLoadVariableOp, + mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp, + mlir::TF::LegacyCallOp, mlir::TF::IfOp, mlir::TF::WhileOp, + mlir::TF::TPUCompileMlirAndExecuteOp>(op); // LINT.ThenChange(tf_to_mlrt.cc:fallback_allow_list) } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 974bc1a56c938f..66aee10db7e050 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -238,7 +238,7 @@ Status ConvertTfMlirToRuntimeExecutable( TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_TPU_JIT, + module, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/VLOG_IS_ON(1))); TF_RETURN_IF_ERROR( @@ -257,7 +257,7 @@ Status ConvertTfMlirToRuntimeExecutable( } else if (options.device_target == TfrtDeviceInfraTarget::kGpu) { TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_GPU_JIT, + module, /*is_supported_by_replicated_brige*/ false, /*is_in_fallback_enabled_mode=*/false)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index b065f29da4ddfc..bc887cdfc966f9 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -25,6 +25,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git 
a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 1953ddd3d93997..98cb26acdba8fa 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 44300ff2209f0d..6cfe2be883e6ca 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:gpu_passes", # fixdeps: keep "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/core:lib", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", @@ -56,8 +57,8 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToStandard", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUToLLVMIRTranslation", @@ -67,6 +68,7 @@ cc_library( "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Parser", @@ -76,9 +78,10 @@ cc_library( 
"@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorToLLVM", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:all_passes", # fixdeps: keep "@local_xla//xla/mlir_hlo:mhlo_passes", @@ -162,11 +165,21 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/platform:refcount", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:OrcShared", "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:mlir_c_runner_utils", "@llvm-project//mlir:mlir_runner_utils", "@local_xla//xla/stream_executor", ], @@ -178,8 +191,13 @@ cc_library( hdrs = ["tf_jit_cache.h"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core/framework:resource_base", + "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", + "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -197,6 +215,7 @@ cc_library( ]), deps = [ "//tensorflow/core:framework", + "//tensorflow/core/framework:resource_base", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:refcount", @@ -206,6 +225,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:mlir_runner_utils", "@local_tsl//tsl/platform:hash", + 
"@local_tsl//tsl/platform:thread_annotations", "@local_xla//xla/stream_executor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index b5537741529d06..4c05366e42af26 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -119,8 +119,8 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, llvm::StringRef host_triple, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, int64_t max_supported_rank, - bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + llvm::ArrayRef unroll_factors, bool print_ptx, + bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile, bool jit_i64_indexed_for_large_tensors) { // Read TF code. std::string hlo_code; @@ -138,9 +138,9 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, GenerateKernelForHloCode(context, hlo_code, architectures, tile_sizes, - unroll_factors, max_supported_rank, print_ptx, - print_llvmir, enable_ftz, index_64bit, - jit_compile, jit_i64_indexed_for_large_tensors, + unroll_factors, print_ptx, print_llvmir, + enable_ftz, index_64bit, jit_compile, + jit_i64_indexed_for_large_tensors, /*apply_cl_options=*/true)); // Get binary. @@ -186,11 +186,6 @@ int main(int argc, char** argv) { llvm::cl::list architectures( "arch", llvm::cl::desc("target architectures (e.g. 
sm_70 or compute_75)"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); - llvm::cl::opt max_supported_rank( - "max-supported-rank", - llvm::cl::desc("maximum supported rank to be guaranteed by rank " - "specialization lowering"), - llvm::cl::init(5)); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); @@ -222,8 +217,8 @@ int main(int argc, char** argv) { auto status = tensorflow::kernel_gen::Run( input_file, output_file, host_triple, architectures, tile_sizes, - unroll_factors, max_supported_rank, print_ptx, print_llvmir, enable_ftz, - index_64bit, jit_compile, jit_i64_indexed_for_large_tensors); + unroll_factors, print_ptx, print_llvmir, enable_ftz, index_64bit, + jit_compile, jit_i64_indexed_for_large_tensors); if (!status.ok()) { LOG(ERROR) << status; return 1; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 8856ae73b0b4bf..d8e7617cc352ba 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -284,7 +284,6 @@ def TFFramework_JITCompileFromStrOp : TFFramework_Op<"jit_compile_from_str", StrAttr:$code, I64ArrayAttr:$tileSizes, I64ArrayAttr:$unrollFactors, - I64Attr:$maxSupportedRank, BoolAttr:$enableFtz, BoolAttr:$index64Bit, BoolAttr:$cpuCodegen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 8bf068241f83fa..c8969f429e1805 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -23,31 +23,36 @@ limitations under the License. 
#include +#include "absl/status/status.h" +#include "llvm/ADT/ArrayRef.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include 
"mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project @@ -64,7 +69,10 @@ limitations under the License. #include "xla/mlir_hlo/transforms/gpu_passes.h" #include "xla/mlir_hlo/transforms/passes.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace kernel_gen { @@ -113,8 +121,7 @@ bool IsSmallAlloc(Value alloc) { Status LowerHloToJITInvocation(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, - bool index_64bit, + bool enable_ftz, bool index_64bit, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { mlir::PassManager pm(module.getContext()); @@ -122,8 +129,7 @@ Status LowerHloToJITInvocation(mlir::ModuleOp module, pm.addNestedPass( mlir::kernel_gen::transforms::CreateFuncToJITInvocationPass( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, - index_64bit, + tile_sizes, unroll_factors, enable_ftz, index_64bit, /*cpu_codegen=*/false, jit_i64_indexed_for_large_tensors)); pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); pm.addNestedPass( @@ -143,8 +149,7 @@ Status LowerHloToJITInvocation(mlir::ModuleOp module, Status 
LowerHlotoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, + llvm::ArrayRef unroll_factors, bool enable_ftz, bool index_64bit, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { mlir::PassManager pm(module.getContext()); @@ -152,14 +157,11 @@ Status LowerHlotoLoops(mlir::ModuleOp module, if (jit_i64_indexed_for_large_tensors) { pm.addNestedPass( mlir::kernel_gen::transforms::CreateFuncToJITInvocationPass( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, - index_64bit, + tile_sizes, unroll_factors, enable_ftz, index_64bit, /*cpu_codegen=*/false, /*jit_i64_indexed_for_large_tensors=*/true)); } - pm.addNestedPass(mlir::mhlo::createRankSpecializationClusterPass()); - pm.addNestedPass( - mlir::mhlo::createRankSpecializationToSCFPass(max_supported_rank)); + pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); @@ -409,6 +411,8 @@ StatusOr> SetupContextAndParseModule( mlir::MLIRContext& context, llvm::StringRef tf_code) { mlir::DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerGPUDialectTranslation(registry); @@ -418,9 +422,10 @@ StatusOr> SetupContextAndParseModule( context.appendDialectRegistry(registry); mlir::OwningOpRef module = mlir::parseSourceString(tf_code, &context); - if (!module) + if (!module) { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "invalid kernel IR"); + } return module; } @@ -428,9 +433,9 @@ StatusOr> GenerateKernelForHloCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool print_ptx, bool print_llvmir, - bool enable_ftz, bool index_64bit, bool jit_compile, - bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { + bool print_ptx, 
bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors, + bool apply_cl_options) { if (jit_compile && jit_i64_indexed_for_large_tensors) { return tensorflow::Status( absl::StatusCode::kInvalidArgument, @@ -446,14 +451,12 @@ StatusOr> GenerateKernelForHloCode( assert(!jit_i64_indexed_for_large_tensors && "expect to have reported an error earlier"); TF_RETURN_IF_ERROR(LowerHloToJITInvocation( - module.get(), tile_sizes, unroll_factors, max_supported_rank, - enable_ftz, index_64bit, + module.get(), tile_sizes, unroll_factors, enable_ftz, index_64bit, /*jit_i64_indexed_for_large_tensors=*/false, apply_cl_options)); } else { - TF_RETURN_IF_ERROR( - LowerHlotoLoops(module.get(), tile_sizes, unroll_factors, - max_supported_rank, enable_ftz, index_64bit, - jit_i64_indexed_for_large_tensors, apply_cl_options)); + TF_RETURN_IF_ERROR(LowerHlotoLoops( + module.get(), tile_sizes, unroll_factors, enable_ftz, index_64bit, + jit_i64_indexed_for_large_tensors, apply_cl_options)); TF_RETURN_IF_ERROR( LowerLoopsToGPU(module.get(), index_64bit, apply_cl_options)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index ac8666874224aa..f92ff42405db38 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/core/platform/statusor.h" namespace tensorflow { @@ -44,9 +45,9 @@ StatusOr> GenerateKernelForHloCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool print_ptx, bool print_llvmir, - bool enable_ftz, bool index_64bit, bool jit_compile, - bool jit_i64_indexed_for_large_tensors, bool apply_cl_options); + bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors, + bool apply_cl_options); } // namespace kernel_gen } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir index d6a7c2698f15ea..e6404a84319c0d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir @@ -80,7 +80,7 @@ func.func @jit_compile_from_str(%ctx : !tf_framework.op_kernel_context) // CHECK: return %[[RES]] %0 = tf_framework.jit_compile_from_str "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %0 : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir index 7da7a482f63774..c1c7a1f2ca9c00 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir +++ 
b/tensorflow/compiler/mlir/tools/kernel_gen/tests/func_to_jit_invocations.mlir @@ -1,11 +1,11 @@ // RUN: kernel-gen-opt %s --split-input-file \ // RUN: --func-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ -// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false" | \ +// RUN: enable-ftz=false cpu-codegen=false" | \ // RUN: FileCheck %s // RUN: kernel-gen-opt %s --split-input-file \ // RUN: --func-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ -// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false \ +// RUN: enable-ftz=false cpu-codegen=false \ // RUN: jit_i64_indexed_for_large_tensors=true" | \ // RUN: FileCheck %s --check-prefix=CHECK-JFLT @@ -30,7 +30,6 @@ func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SAME: { // CHECK-SAME: cpuCodegen = false // CHECK-SAME: enableFtz = false -// CHECK-SAME: maxSupportedRank = 32 : i64 // CHECK-SAME: tileSizes = [1, 2, 3] // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -49,7 +48,6 @@ func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false // CHECK-JFLT-SAME: index64Bit = true -// CHECK-JFLT-SAME: maxSupportedRank = 32 // CHECK-JFLT-SAME: tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT: %[[JIT_0:.*]] = tf_framework.jit_execute %[[JIT]](%[[ARG0]]) @@ -82,7 +80,6 @@ func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*x // CHECK-SAME: { // CHECK-SAME: cpuCodegen = false // CHECK-SAME: enableFtz = false -// CHECK-SAME: maxSupportedRank = 32 : i64 // CHECK-SAME: tileSizes = [1, 2, 3] // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -114,7 +111,6 @@ func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*x // CHECK-JFLT-SAME: { // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false -// CHECK-JFLT-SAME: maxSupportedRank = 32 : i64 // CHECK-JFLT-SAME: 
tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT-SAME: } @@ -149,7 +145,6 @@ func.func @reciprocal(%arg0: tensor<*xf32>) // CHECK-SAME: cpuCodegen = false, // CHECK-SAME: enableFtz = false, // CHECK-SAME: index64Bit = false, -// CHECK-SAME: maxSupportedRank = 32 : i64, // CHECK-SAME: tileSizes = [1, 2, 3], // CHECK-SAME: unrollFactors = [3, 2, 1] // CHECK-SAME: } @@ -168,7 +163,6 @@ func.func @reciprocal(%arg0: tensor<*xf32>) // CHECK-JFLT-SAME: cpuCodegen = false // CHECK-JFLT-SAME: enableFtz = false // CHECK-JFLT-SAME: index64Bit = true -// CHECK-JFLT-SAME: maxSupportedRank = 32 // CHECK-JFLT-SAME: tileSizes = [1, 2, 3] // CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] // CHECK-JFLT: %[[JIT_0:.*]] = tf_framework.jit_execute %[[JIT]](%[[ARG0]]) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir index a1b35ccecc993e..686b34e0d138db 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2.mlir @@ -1,7 +1,131 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @AddV2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) - -> tensor<*xf32> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> +func.func @AddV2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = 
shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xf32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xf32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xf32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_add %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xf32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xf32>) -> tensor + %20 = chlo.broadcast_add %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xf32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_add %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xf32> + scf.yield %cast : tensor<*xf32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor 
-> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xf32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_add %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xf32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_add %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xf32>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<3xindex>) -> tensor + %34 = 
chlo.broadcast_add %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xf32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_add %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xf32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xf32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_add %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xf32> + scf.yield %cast_1 : tensor<*xf32> + } + scf.yield %31 : tensor<*xf32> + } + scf.yield %29 : tensor<*xf32> + } + scf.yield %27 : tensor<*xf32> + } + scf.yield %25 : tensor<*xf32> + } + scf.yield %18 : tensor<*xf32> + } + scf.yield %16 : tensor<*xf32> + } + %10 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xf32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xf32>, tensor) 
-> tensor<*xf32> + return %13 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir index ee01fe543eaac4..f38a2dca1bc8cd 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/add_v2_unsigned.mlir @@ -1,7 +1,131 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @AddV2(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) - -> tensor<*xui32> attributes {tf_entry, llvm.emit_c_interface} { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xui32>, tensor<*xui32>) -> tensor<*xui32> - return %0 : tensor<*xui32> +func.func @AddV2(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) -> tensor<*xui32> attributes {llvm.emit_c_interface, tf_entry} { + %0 = shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xui32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xui32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_add %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xui32> + 
scf.yield %cast : tensor<*xui32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xui32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xui32>) -> tensor + %20 = chlo.broadcast_add %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xui32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_add %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xui32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_add 
%27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xui32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_add %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xui32>) { + %30 = shape.broadcast %19#0, %2 : tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_add %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xui32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<4xindex>) -> 
tensor + %36 = chlo.broadcast_add %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_add %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } + scf.yield %31 : tensor<*xui32> + } + scf.yield %29 : tensor<*xui32> + } + scf.yield %27 : tensor<*xui32> + } + scf.yield %25 : tensor<*xui32> + } + scf.yield %18 : tensor<*xui32> + } + scf.yield %16 : tensor<*xui32> + } + %10 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xui32>, tensor) -> tensor<*xui32> + return %13 : tensor<*xui32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir new file mode 100644 index 00000000000000..1facc06ee500e9 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/minimum.mlir @@ -0,0 +1,131 @@ +// RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 + +func.func @Minimum_GPU_DT_UINT32_DT_UINT32(%arg0: tensor<*xui32>, %arg1: tensor<*xui32>) -> tensor<*xui32> attributes {llvm.emit_c_interface, tf_entry} { + %0 
= shape.const_shape [1, 1, 1, 1, 1] : tensor<5xindex> + %c5 = arith.constant 5 : index + %1 = shape.const_shape [1, 1, 1, 1] : tensor<4xindex> + %c4 = arith.constant 4 : index + %2 = shape.const_shape [1, 1, 1] : tensor<3xindex> + %c3 = arith.constant 3 : index + %3 = shape.const_shape [1, 1] : tensor<2xindex> + %c2 = arith.constant 2 : index + %4 = shape.const_shape [1] : tensor<1xindex> + %c1 = arith.constant 1 : index + %5 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %6 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %7 = shape.num_elements %5 : tensor -> index + %8 = arith.cmpi eq, %7, %c1 : index + %9 = scf.if %8 -> (tensor<*xui32>) { + %14 = shape.num_elements %6 : tensor -> index + %from_elements = tensor.from_elements %14 : tensor<1xindex> + %15 = mhlo.reshape %arg0 : (tensor<*xui32>) -> tensor + %16 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %17 = chlo.broadcast_minimum %15, %16 : (tensor, tensor) -> tensor + %cast = tensor.cast %17 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %14 = shape.num_elements %6 : tensor -> index + %15 = arith.cmpi eq, %14, %c1 : index + %16 = scf.if %15 -> (tensor<*xui32>) { + %17 = shape.num_elements %5 : tensor -> index + %from_elements = tensor.from_elements %17 : tensor<1xindex> + %18 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %19 = mhlo.reshape %arg1 : (tensor<*xui32>) -> tensor + %20 = chlo.broadcast_minimum %18, %19 : (tensor, tensor) -> tensor + %cast = tensor.cast %20 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %17 = shape.shape_eq %5, %6 : tensor, tensor + %18 = scf.if %17 -> (tensor<*xui32>) { + %19 = shape.any %5, %6 : tensor, tensor -> tensor + %20 = shape.num_elements %19 : tensor -> index + %from_elements = tensor.from_elements %20 : tensor<1xindex> + %21 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + 
%22 = mhlo.dynamic_reshape %arg1, %from_elements : (tensor<*xui32>, tensor<1xindex>) -> tensor + %23 = chlo.broadcast_minimum %21, %22 : (tensor, tensor) -> tensor + %cast = tensor.cast %23 : tensor to tensor<*xui32> + scf.yield %cast : tensor<*xui32> + } else { + %19:2 = chlo.minimum_broadcast_shapes %5, %6 : tensor, tensor -> tensor, tensor + %20 = shape.rank %19#0 : tensor -> index + %21 = shape.rank %19#1 : tensor -> index + %22 = arith.cmpi sgt, %20, %21 : index + %23 = arith.select %22, %20, %21 : index + %24 = arith.cmpi ule, %23, %c1 : index + %25 = scf.if %24 -> (tensor<*xui32>) { + %26 = shape.broadcast %19#0, %4 : tensor, tensor<1xindex> -> tensor + %cast = tensor.cast %26 : tensor to tensor<1xindex> + %27 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<1xindex>) -> tensor + %28 = shape.broadcast %19#1, %4 : tensor, tensor<1xindex> -> tensor + %cast_0 = tensor.cast %28 : tensor to tensor<1xindex> + %29 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<1xindex>) -> tensor + %30 = chlo.broadcast_minimum %27, %29 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %30 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %26 = arith.cmpi ule, %23, %c2 : index + %27 = scf.if %26 -> (tensor<*xui32>) { + %28 = shape.broadcast %19#0, %3 : tensor, tensor<2xindex> -> tensor + %cast = tensor.cast %28 : tensor to tensor<2xindex> + %29 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<2xindex>) -> tensor + %30 = shape.broadcast %19#1, %3 : tensor, tensor<2xindex> -> tensor + %cast_0 = tensor.cast %30 : tensor to tensor<2xindex> + %31 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<2xindex>) -> tensor + %32 = chlo.broadcast_minimum %29, %31 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %32 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %28 = arith.cmpi ule, %23, %c3 : index + %29 = scf.if %28 -> (tensor<*xui32>) { + %30 = shape.broadcast %19#0, %2 : 
tensor, tensor<3xindex> -> tensor + %cast = tensor.cast %30 : tensor to tensor<3xindex> + %31 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<3xindex>) -> tensor + %32 = shape.broadcast %19#1, %2 : tensor, tensor<3xindex> -> tensor + %cast_0 = tensor.cast %32 : tensor to tensor<3xindex> + %33 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<3xindex>) -> tensor + %34 = chlo.broadcast_minimum %31, %33 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %34 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %30 = arith.cmpi ule, %23, %c4 : index + %31 = scf.if %30 -> (tensor<*xui32>) { + %32 = shape.broadcast %19#0, %1 : tensor, tensor<4xindex> -> tensor + %cast = tensor.cast %32 : tensor to tensor<4xindex> + %33 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<4xindex>) -> tensor + %34 = shape.broadcast %19#1, %1 : tensor, tensor<4xindex> -> tensor + %cast_0 = tensor.cast %34 : tensor to tensor<4xindex> + %35 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<4xindex>) -> tensor + %36 = chlo.broadcast_minimum %33, %35 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %36 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } else { + %32 = arith.cmpi ule, %23, %c5 : index + cf.assert %32, "Input for dynamic binary or n-ary op lowering was of a rank greater than 5" + %33 = shape.broadcast %19#0, %0 : tensor, tensor<5xindex> -> tensor + %cast = tensor.cast %33 : tensor to tensor<5xindex> + %34 = mhlo.dynamic_reshape %arg0, %cast : (tensor<*xui32>, tensor<5xindex>) -> tensor + %35 = shape.broadcast %19#1, %0 : tensor, tensor<5xindex> -> tensor + %cast_0 = tensor.cast %35 : tensor to tensor<5xindex> + %36 = mhlo.dynamic_reshape %arg1, %cast_0 : (tensor<*xui32>, tensor<5xindex>) -> tensor + %37 = chlo.broadcast_minimum %34, %36 : (tensor, tensor) -> tensor + %cast_1 = tensor.cast %37 : tensor to tensor<*xui32> + scf.yield %cast_1 : tensor<*xui32> + } + scf.yield %31 : 
tensor<*xui32> + } + scf.yield %29 : tensor<*xui32> + } + scf.yield %27 : tensor<*xui32> + } + scf.yield %25 : tensor<*xui32> + } + scf.yield %18 : tensor<*xui32> + } + scf.yield %16 : tensor<*xui32> + } + %10 = shape.shape_of %arg0 : tensor<*xui32> -> tensor + %11 = shape.shape_of %arg1 : tensor<*xui32> -> tensor + %12 = shape.broadcast %10, %11 : tensor, tensor -> tensor + %13 = mhlo.dynamic_reshape %9, %12 : (tensor<*xui32>, tensor) -> tensor<*xui32> + return %13 : tensor<*xui32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir index 8685ab17faee7a..2d3c8e6f5b9ef7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/tanh.mlir @@ -1,6 +1,11 @@ // RUN: hlo_to_kernel --input=%s --output=%t --unroll_factors=4 --tile_sizes=256 --arch=sm_70 -func.func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> attributes {tf_entry} { - %0 = mhlo.tanh %arg : tensor<*xf32> - return %0 : tensor<*xf32> +func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf_entry} { + %0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor + %1 = shape.num_elements %0 : tensor -> index + %from_elements = tensor.from_elements %1 : tensor<1xindex> + %2 = mhlo.dynamic_reshape %arg0, %from_elements : (tensor<*xf32>, tensor<1xindex>) -> tensor + %3 = mhlo.tanh %2 : tensor + %4 = mhlo.dynamic_reshape %3, %0 : (tensor, tensor) -> tensor<*xf32> + return %4 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir index 22c9572dfd810c..da3ca471a857e4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/isinf.mlir @@ -1,6 +1,6 @@ // RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | \ -// RUN: mlir-hlo-opt 
--mhlo-rank-specialization-cluster \ -// RUN: --mhlo-rank-specialization-to-scf --hlo-legalize-to-linalg \ +// RUN: mlir-hlo-opt \ +// RUN: --hlo-legalize-to-linalg \ // RUN: --empty-tensor-to-alloc-tensor \ // RUN: --computeop-and-func-bufferize --canonicalize | \ // RUN: kernel-gen-opt -allow-unregistered-dialect \ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir index ccd7a901fd0513..b5055e9ba4c8ab 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -82,7 +82,7 @@ func.func @jit_compile(%ctx : !tf_framework.op_kernel_context) func.func @jit_compile_from_str_wo_ctx() -> !tf_framework.jit_callable { %callable = tf_framework.jit_compile_from_str "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %callable : !tf_framework.jit_callable } @@ -92,7 +92,7 @@ func.func @jit_compile_from_str(%ctx : !tf_framework.op_kernel_context) -> !tf_framework.jit_callable { %callable = tf_framework.jit_compile_from_str %ctx , "placeholder" { architectures = ["sm_123", "sm_456"], tileSizes = [1, 2, 3], - unrollFactors = [4], maxSupportedRank = 3 : i64, enableFtz = false, + unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %callable : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir deleted file mode 100644 index b8a227688c9dcb..00000000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tanh.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-opt %s --test-tf-lower-tf --xla-legalize-tf | \ -// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster \ -// RUN: 
--mhlo-rank-specialization-to-scf --hlo-legalize-to-linalg \ -// RUN: --empty-tensor-to-alloc-tensor \ -// RUN: --computeop-and-func-bufferize --canonicalize | \ -// RUN: kernel-gen-opt -allow-unregistered-dialect \ -// RUN: --shape-to-descriptors \ -// RUN: --canonicalize --kernelgen-final-bufferize | \ -// RUN: FileCheck %s - -// Test whether all shape computations required for tanh can be lowered to -// the standard dialect, scf and descriptors. We check for a sparse pattern here, -// as each lowering pattern is already tested and we just care for the -// integration. -// TODO: Expand this pattern once things have stabilized. -// CHECK-LABEL: @tanh -func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: alloc - // CHECK: scf.for - // CHECK: memref.reshape - // CHECK: alloc - // CHECK: linalg.generic - // CHECK: memref.reshape - %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir deleted file mode 100644 index c30b21d8cd3162..00000000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-lmhlo.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | \ -// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster \ -// RUN: --mhlo-rank-specialization-to-scf --chlo-legalize-to-hlo \ -// RUN: --hlo-legalize-to-linalg --empty-tensor-to-alloc-tensor \ -// RUN: --computeop-and-func-bufferize | \ -// RUN: kernel-gen-opt --shape-to-descriptors \ -// RUN: --canonicalize --kernelgen-final-bufferize - -func.func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @tan(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Tan"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : 
tensor<*xf32> -} - -func.func @tanh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Tanh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @sin(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Sin"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @sinh(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Sinh"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -func.func @erf(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.Erf"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index ed65470eb72728..4da5051a7bd2ac 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -151,7 +151,7 @@ func.func @is_valid_memref(%buf: memref) -> i1 { // ----- -// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i64, i1, i1, i1) -> !llvm.ptr +// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i1, i1, i1) -> !llvm.ptr // CHECK: llvm.mlir.global internal constant @[[CODE:jit_module_code_[0-9]+]]("placeholder\00") // CHECK: @jit_compile_from_str(%[[CTX:.*]]: !llvm.ptr) @@ -184,17 +184,16 @@ func.func @jit_compile_from_str(%ctx: !tf_framework.op_kernel_context) // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i64) // CHECK: llvm.store %[[C4]], %[[PTR]] - // CHECK-DAG: %[[MAX_RANK:.*]] = llvm.mlir.constant(3 : i64) // CHECK-DAG: %[[ENABLE_FTZ:.*]] = llvm.mlir.constant(false) // CHECK-DAG: %[[CPU_CODEGEN:.*]] = llvm.mlir.constant(false) // CHECK: %[[RES:.*]] = llvm.call @_mlir_ciface_tf_jit_compile // 
CHECK-SAME: %[[CTX]], %[[CODE_PTR]], // CHECK-SAME: %[[NUM_TILE_SIZES]], %[[TILE_SIZES]], // CHECK-SAME: %[[NUM_UNROLL_FACTORS]], %[[UNROLL_FACTORS]], - // CHECK-SAME: %[[MAX_RANK]], %[[ENABLE_FTZ]], %[[CPU_CODEGEN]] + // CHECK-SAME: %[[ENABLE_FTZ]], %[[CPU_CODEGEN]] // CHECK: llvm.return %[[RES]] %0 = tf_framework.jit_compile_from_str %ctx, "placeholder" { - tileSizes = [1, 2, 3], unrollFactors = [4], maxSupportedRank = 3 : i64, + tileSizes = [1, 2, 3], unrollFactors = [4], enableFtz = false, index64Bit = false, cpuCodegen = false } func.return %0 : !tf_framework.jit_callable } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index 34cbb4069ffeac..b2f717046ce1d1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -19,22 +19,44 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/Support/Error.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include 
"mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.pb.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tsl/framework/allocator.h" #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) #include @@ -162,8 +184,8 @@ void InitializeLlvmCompiler() { llvm::Expected> Compile( const std::string code, llvm::SmallVectorImpl& architectures, llvm::SmallVectorImpl& tile_sizes, - llvm::SmallVectorImpl& unroll_factors, int64_t max_supported_rank, - bool enable_ftz, bool index_64bit) { + llvm::SmallVectorImpl& unroll_factors, bool enable_ftz, + bool index_64bit) { std::string cache_dir; if (const char* dir = getenv(kTFJitCacheDirEnvVar.data())) { cache_dir = dir; @@ -197,7 +219,6 @@ llvm::Expected> Compile( tensorflow::StatusOr> status_or_module = tensorflow::kernel_gen::GenerateKernelForHloCode( context, code, architectures, tile_sizes, unroll_factors, - max_supported_rank, /*print_ptx=*/false, /*print_llvmir=*/false, enable_ftz, index_64bit, /*jit_compile=*/false, @@ -261,8 +282,7 @@ llvm::SmallVector SmallVectorFromCArray(int64_t num_elements, extern "C" void* _mlir_ciface_tf_jit_compile( void* op_kernel_ctx, char* code, int64_t 
num_tile_sizes, int64_t* tile_sizes_ptr, int64_t num_unroll_factors, - int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz, - bool index_64bit) { + int64_t* unroll_factors_ptr, bool enable_ftz, bool index_64bit) { // Get the resource manager. auto* ctx = static_cast(op_kernel_ctx); tensorflow::ResourceMgr* rm = ctx->resource_manager(); @@ -303,8 +323,8 @@ extern "C" void* _mlir_ciface_tf_jit_compile( // Lookup or compile the execution module. ExecutionEngine* engine = jit_cache->LookupOrCompile(code, [&]() { - return Compile(code, architectures, tile_sizes, unroll_factors, - max_supported_rank, enable_ftz, index_64bit); + return Compile(code, architectures, tile_sizes, unroll_factors, enable_ftz, + index_64bit); }); if (engine == nullptr) { ReportError(op_kernel_ctx, ErrorCode::UNKNOWN, "JIT compilation failed."); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h index 3a7c879d2c64c9..a62dc2c7020ab7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -36,8 +36,7 @@ extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_report_error( extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_jit_compile( void* op_kernel_ctx, char* code, int64_t num_tile_sizes, int64_t* tile_sizes_ptr, int64_t num_unroll_factors, - int64_t* unroll_factors_ptr, int64_t max_supported_rank, bool enable_ftz, - bool index_64bit); + int64_t* unroll_factors_ptr, bool enable_ftz, bool index_64bit); extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_jit_execute( void* op_kernel_ctx, void* callable, void* result, int64_t num_args, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc index 6d2e954595645f..0d3b4f2d4a0eaf 100644 --- 
a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/stream.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -31,6 +30,10 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif + static void ReportInternalError(tensorflow::OpKernelContext *ctx, const std::string msg) { if (ctx == nullptr) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h index 79e2c0cc530552..be1325c5a4dc16 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h @@ -18,9 +18,12 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" #include "tsl/platform/hash.h" +#include "tsl/platform/thread_annotations.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc index 4d69d754a0a8b8..b2e3cc19b581c5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc @@ -19,6 +19,11 @@ limitations under the License. #include #include +#include "llvm/Support/Error.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" + namespace mlir { namespace kernel_gen { namespace tf_framework { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h index 414e7954be271e..f03e5778f6c8c5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h @@ -21,7 +21,11 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/thread_annotations.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc index af0943ded1f8d8..32faed506e52b4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc @@ -60,7 +60,7 @@ void RemoveCopyIfTargetOnlyRead(func::FuncOp func) { } continue; } - if (auto effect_interface = cast(user)) { + if (auto effect_interface = dyn_cast(user)) { if (reader) { at_most_one_read = false; } else { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc index d87271e161529b..60f8876f109553 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/func_to_jit_invocations.cc @@ -176,8 +176,7 @@ LogicalResult RewriteToLargeSizeJit(FuncOp op) { void PackJITCompileOp(tf_framework::JITCompileOp op, llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, + llvm::ArrayRef unroll_factors, bool enable_ftz, bool index_64bit, bool cpu_codegen) { IRRewriter rewriter(op.getContext()); Block *body = op.SingleBlock::getBody(); @@ -219,7 +218,6 @@ void PackJITCompileOp(tf_framework::JITCompileOp op, op, op->getResultTypes(), op.getCtx(), rewriter.getStringAttr(code), rewriter.getI64ArrayAttr(tile_sizes), 
rewriter.getI64ArrayAttr(unroll_factors), - rewriter.getI64IntegerAttr(max_supported_rank), rewriter.getBoolAttr(enable_ftz), rewriter.getBoolAttr(index_64bit), rewriter.getBoolAttr(cpu_codegen)); } @@ -231,12 +229,11 @@ struct FuncToJITInvocationPass : public impl::FuncToJITInvocationPassBase { explicit FuncToJITInvocationPass(llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, - bool index_64bit, bool cpu_codegen, + bool enable_ftz, bool index_64bit, + bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) { tile_sizes_ = tile_sizes; unroll_factors_ = unroll_factors; - max_supported_rank_ = max_supported_rank; enable_ftz_ = enable_ftz; index_64bit_ = index_64bit; cpu_codegen_ = cpu_codegen; @@ -255,9 +252,9 @@ struct FuncToJITInvocationPass } getOperation().walk([&](tf_framework::JITCompileOp op) { - PackJITCompileOp( - op, tile_sizes_, unroll_factors_, max_supported_rank_, enable_ftz_, - index_64bit_ || jit_i64_indexed_for_large_tensors_, cpu_codegen_); + PackJITCompileOp(op, tile_sizes_, unroll_factors_, enable_ftz_, + index_64bit_ || jit_i64_indexed_for_large_tensors_, + cpu_codegen_); }); } }; @@ -266,11 +263,11 @@ struct FuncToJITInvocationPass std::unique_ptr> CreateFuncToJITInvocationPass( llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool enable_ftz, bool index_64bit, - bool cpu_codegen, bool jit_i64_indexed_for_large_tensors) { + bool enable_ftz, bool index_64bit, bool cpu_codegen, + bool jit_i64_indexed_for_large_tensors) { return std::make_unique( - tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit, - cpu_codegen, jit_i64_indexed_for_large_tensors); + tile_sizes, unroll_factors, enable_ftz, index_64bit, cpu_codegen, + jit_i64_indexed_for_large_tensors); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc 
b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index b700189aec6608..37b799bb9666ea 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -112,12 +112,11 @@ class GpuKernelToBlobPass return tensorflow::errors::Internal( "Could not parse ROCm architecture prefix (expected gfx)"); } - std::string libdevice_dir = tensorflow::RocdlRoot(); auto llvm_module_copy = llvm::CloneModule(*llvmModule); auto hsaco_or = xla::gpu::amdgpu::CompileToHsaco( llvm_module_copy.get(), tensorflow::se::RocmComputeCapability{arch_str}, options, - libdevice_dir, options.DebugString()); + options.DebugString()); if (!hsaco_or.ok()) { return tensorflow::errors::Internal("Failure when generating HSACO"); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 355030d6009f18..45e248ceb904ff 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -66,8 +66,8 @@ std::unique_ptr> CreateBufferReusePass(); // framework. std::unique_ptr> CreateFuncToJITInvocationPass( llvm::ArrayRef tile_sizes = {}, - llvm::ArrayRef unroll_factors = {}, int64_t max_supported_rank = 5, - bool enable_ftz = false, bool index_64bit = false, bool cpu_codegen = false, + llvm::ArrayRef unroll_factors = {}, bool enable_ftz = false, + bool index_64bit = false, bool cpu_codegen = false, bool jit_i64_indexed_for_large_tensors = false); // Pass for applying LLVM legalization patterns. 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 0aa80747ef5ebc..4f92be70d25397 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -52,8 +52,6 @@ def FuncToJITInvocationPass : Pass<"func-to-jit-invocation", "mlir::func::FuncOp "llvm::cl::ZeroOrMore">, ListOption<"unroll_factors_", "unroll-factors", "int64_t", "Unrolling in each tile dimension", "llvm::cl::ZeroOrMore">, - Option<"max_supported_rank_", "max-supported-rank", "int64_t", - /*default=*/"", "Max rank that this kernel supports">, Option<"enable_ftz_", "enable-ftz", "bool", /*default=*/"", "Enable the denormal flush to zero mode when generating code">, Option<"index_64bit_", "index_64bit", "bool", /*default=*/"", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index fe649e1edeb723..cffa5e7b44691e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -271,8 +271,6 @@ class JITCompileFromStrOpConverter ConvertIntegerArrayAttrToStackAllocatedArray( loc, rewriter.getI64Type(), rewriter.getI64Type(), op.getUnrollFactors(), &rewriter); - Value max_supported_rank = rewriter.create( - loc, rewriter.getI64Type(), op.getMaxSupportedRankAttr()); Value enable_ftz = rewriter.create( loc, rewriter.getI1Type(), op.getEnableFtzAttr()); Value index_64bit = rewriter.create( @@ -285,8 +283,8 @@ class JITCompileFromStrOpConverter op, getVoidPtrType(), tf_func_ref, llvm::ArrayRef({adaptor.getCtx(), jit_module_code, tile_sizes.first, tile_sizes.second, unroll_factors.first, - unroll_factors.second, max_supported_rank, enable_ftz, - index_64bit, cpu_codegen})); + 
unroll_factors.second, enable_ftz, index_64bit, + cpu_codegen})); return success(); } @@ -304,7 +302,6 @@ class JITCompileFromStrOpConverter /*int64_t* tile_sizes_ptr*/ ptr_ty, /*int64_t num_unroll_factors*/ i64_ty, /*int64_t* unroll_factors_ptr*/ ptr_ty, - /*int64_t max_supported_rank*/ i64_ty, /*bool enable_ftz*/ i1_ty, /*bool index_64bit*/ i1_ty, /*bool cpu_codegen*/ i1_ty}); diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 50f81b588c5cd2..76a76a4ab82747 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -1018,7 +1018,7 @@ def f(x, y): # x: f32[b, 4] y: f32[2, b, 4] %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> %3 = stablehlo.constant dense<4> : tensor<1xi32> %4 = "stablehlo.concatenate"(%0, %2, %3) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - %5 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi32>) -> tensor<2x?x4xf32> + %5 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %4) {broadcast_dimensions = array} : (tensor, tensor<3xi32>) -> tensor<2x?x4xf32> %6 = stablehlo.add %5, %arg2 : (tensor<2x?x4xf32>, tensor<2x?x4xf32>) -> tensor<2x?x4xf32> return %5, %6 : tensor<2x?x4xf32>, tensor<2x?x4xf32> } diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index b91fb494667c5f..91ef722b52db86 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -298,6 +298,7 @@ cc_library( name = "trt_engine_resource_op_kernels", srcs = ["kernels/trt_engine_resource_ops.cc"], copts = tf_copts(), + features = ["-layering_check"], visibility = ["//tensorflow/core:__subpackages__"], deps = [ ":trt_allocator", @@ -994,6 +995,7 @@ tf_cuda_library( name = "trt_plugins", srcs = ["plugin/trt_plugin.cc"], hdrs = ["plugin/trt_plugin.h"], + features 
= ["-layering_check"], deps = [ "//tensorflow/core:framework_lite", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c565cef1489532..4f6706c11b3234 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -944,8 +944,8 @@ tf_cc_test( "@local_xla//xla/client:local_client", "@local_xla//xla/service:compiler", "@local_xla//xla/service:platform_util", - "@local_xla//xla/stream_executor:multi_platform_manager", "@local_xla//xla/stream_executor:platform", + "@local_xla//xla/stream_executor:platform_manager", ], ) @@ -1171,14 +1171,15 @@ cc_library( "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow/transforms:bridge", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", "//tensorflow/compiler/mlir/tf2xla/api/v1:cluster_tf", "//tensorflow/compiler/mlir/tf2xla/api/v1:tf_dialect_to_executor", "//tensorflow/compiler/mlir/tf2xla/api/v2:cluster_tf", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tf2xla/internal:mlir_bridge_pass_util", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/common_runtime:device_set", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index b39d6f8f5d2bed..b8d91294ca7b18 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -42,7 +42,7 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, TF_RETURN_IF_ERROR(flib_runtime->Instantiate( name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle)); *fbody = 
flib_runtime->GetFunctionBody(func_handle); - return OkStatus(); + return absl::OkStatus(); } Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, @@ -57,7 +57,7 @@ Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, &func_handle)); fbodies->push_back(flib_runtime->GetFunctionBody(func_handle)); } - return OkStatus(); + return absl::OkStatus(); } Status CondConstInputIndices( @@ -84,7 +84,7 @@ Status CondConstInputIndices( const_input_idxs->push_back(i + 1); } } - return OkStatus(); + return absl::OkStatus(); } Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, @@ -133,7 +133,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, } } } - return OkStatus(); + return absl::OkStatus(); } else if (node.op() == "If" || node.op() == "StatelessIf") { const FunctionBody* fthen = nullptr; const FunctionBody* felse = nullptr; @@ -162,7 +162,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, const_input_idxs->push_back(i); } } - return OkStatus(); + return absl::OkStatus(); } else if (op_def != nullptr) { return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def, const_input_idxs); @@ -193,7 +193,7 @@ Status BackwardsConstAnalysis( !edge_filter_input) { VLOG(5) << "Using cached argument indices on graph " << &g; *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value(); - return OkStatus(); + return absl::OkStatus(); } auto edge_filter = [&](const Edge& e) { return edge_filter_input ? 
edge_filter_input(e) : true; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 1d3b6d3fe873b0..4c7048d205a324 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -128,14 +128,14 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge)); } *pred = OutputTensor(pred_edge->src(), pred_edge->src_output()); - return OkStatus(); + return absl::OkStatus(); } Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { const Edge* val_edge; TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); *val = OutputTensor(val_edge->src(), val_edge->src_output()); - return OkStatus(); + return absl::OkStatus(); } bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, @@ -394,7 +394,7 @@ Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, Status Conditional::AddMerge(Node* m) { merges_.insert(m); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddSwitch(Node* s) { @@ -410,7 +410,7 @@ Status Conditional::AddSwitch(Node* s) { } switches_.insert(s); parent_->AddSwitchId(s->id()); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildArgumentNodes() { @@ -492,7 +492,7 @@ Status Conditional::BuildArgumentNodes() { } } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, @@ -741,7 +741,7 @@ Status Conditional::ExtractBodies(Graph* graph) { } } } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildIfNode(Graph* graph, @@ -834,7 +834,7 @@ Status Conditional::BuildIfNode(Graph* graph, TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddInputEdges( @@ -871,7 +871,7 @@ Status 
Conditional::AddInputEdges( for (Node* n : external_control_inputs_) { graph->AddControlEdge(n, if_node_); } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::AddOutputEdges( @@ -910,7 +910,7 @@ Status Conditional::AddOutputEdges( graph->AddControlEdge(if_node_, n); } - return OkStatus(); + return absl::OkStatus(); } Status Conditional::BuildAndReplace( @@ -918,7 +918,7 @@ Status Conditional::BuildAndReplace( std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " << NodesToString(this->merges_); - if (replaced_) return OkStatus(); + if (replaced_) return absl::OkStatus(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); TF_RETURN_IF_ERROR(BuildArgumentNodes()); @@ -944,7 +944,7 @@ Status Conditional::BuildAndReplace( "Converting to If failed."); replaced_ = true; - return OkStatus(); + return absl::OkStatus(); } string Conditional::name() const { @@ -966,7 +966,7 @@ Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id)); state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); - return OkStatus(); + return absl::OkStatus(); } StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, @@ -1018,7 +1018,7 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { changed.erase(n); } } - return OkStatus(); + return absl::OkStatus(); } // Returns the most restrictive branch of two branches or neither. This is the @@ -1160,7 +1160,7 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. 
- if (state_map_.IsDead(state_map_.LookupCondId(dst))) return OkStatus(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return absl::OkStatus(); int data_inputs = 0; for (auto e : dst->in_edges()) { @@ -1183,7 +1183,7 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { dst->name(), " only has ", data_inputs, " inputs, while only merge nodes with two inputs supported."); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { @@ -1201,13 +1201,14 @@ Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { FormatNodeForError(*dst)); state_map_.ResetCondId(dst, id_or.value()); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!state_map_.IsDead(state_map_.LookupCondId(node))) return OkStatus(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) + return absl::OkStatus(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1239,7 +1240,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { : non_dead_edge->src_output(), dst_node, dst_port); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { @@ -1251,7 +1252,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. 
StateMap::CondId dst_id = state_map_.LookupCondId(node); - if (state_map_.IsDead(dst_id)) return OkStatus(); + if (state_map_.IsDead(dst_id)) return absl::OkStatus(); BranchType b; OutputTensor pred; @@ -1272,7 +1273,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { } b = state_map_.FindBranchOf(dst_id, val); if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return OkStatus(); + return absl::OkStatus(); } VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " @@ -1309,7 +1310,7 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port, dst_node, dst_input); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { @@ -1325,7 +1326,7 @@ Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { << " @ " << state_map_.AncestorStateToString(dst); if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeCond::DetermineAncestorState(Node* dst) { @@ -1359,7 +1360,7 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { id = insert(id, src); } state_map_.ResetAncestorId(dst, id); - return OkStatus(); + return absl::OkStatus(); } void FunctionalizeCond::DeleteReachableAndDeadNodes( @@ -1504,7 +1505,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { // No merges mean no switch values consumed (as only considering values // fetchable as output of merge); DeleteReachableAndDeadNodes(merge_order); - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); @@ -1574,7 +1575,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { DeleteReachableAndDeadNodes(merge_order); - return OkStatus(); + return absl::OkStatus(); } void FunctionalizeCond::DumpGraphWithCondState(const string& name) { diff --git 
a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1d13170daa43e6..2bad3b58d34761 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -117,7 +117,8 @@ Status AddFunctionDefToGraphLibrary( // `graph->flib_def().default_registry()` which is done in the following line // (we have to use `LookUp` instead of `Contains` or `Find` because the latter // both don't check the default registry). - if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) return OkStatus(); + if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) + return absl::OkStatus(); const FunctionDef* new_fdef = fld->Find(func_name); DCHECK(new_fdef != nullptr); @@ -197,7 +198,7 @@ Status FunctionalizeControlFlowForNodeAssociatedFunctions( } } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForFunction( @@ -210,7 +211,7 @@ Status FunctionalizeControlFlowForFunction( // Convert the function to a graph. 
FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = OkStatus(); + Status ret_status = absl::OkStatus(); auto cleanup_handle = gtl::MakeCleanup([&]() { auto s = flr->ReleaseHandle(handle); if (!s.ok()) { @@ -304,7 +305,7 @@ Status FunctionalizeControlFlow(Graph* graph, VLOG(2) << "FunctionalizeControlFlow (final): " << DumpGraphToFile("functionalize_final", *graph, library); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, @@ -319,7 +320,7 @@ Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, include_functions)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeControlFlowForXlaPass::Run( @@ -388,7 +389,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( DumpGraphToFile("functionalize_control_flow_after", *graph, options.flib_def); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index a116f905097989..25a08224c8b946 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -56,7 +56,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, *then_fn = *result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); *else_fn = *result; - return OkStatus(); + return absl::OkStatus(); } } return errors::NotFound("No If node found in graph"); @@ -317,7 +317,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, *cond = *result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); *body = *result; - return OkStatus(); + return absl::OkStatus(); } } return errors::NotFound("No While node found in graph"); diff 
--git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 71e3f9f69e4445..3260fa139cb673 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -78,7 +78,7 @@ Status ExtractWhileLoopFrames( } } - return OkStatus(); + return absl::OkStatus(); } // Check that the graph has no cycle containing the given node. @@ -99,7 +99,7 @@ Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 0294c018e512db..7feb847f982c2e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -103,7 +103,7 @@ Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } } - return OkStatus(); + return absl::OkStatus(); } StatusOr BuildArgNode(Graph* graph, DataType type, int index) { @@ -206,7 +206,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), squash_src_outputs, &node_map, output)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, @@ -216,7 +216,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, VLOG(2) << "Skipping functionalization for frame " << frame->name << " because it has control flow nodes that are filtered out by " "the specified node filter."; - return OkStatus(); + return absl::OkStatus(); } VLOG(2) << "Frame " << frame->name << " before: " << DumpGraphToFile("functionalize_before", *graph, library); @@ -501,7 +501,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, VLOG(2) 
<< "Frame " << frame->name << " after: " << DumpGraphToFile("functionalize_after", *graph, library); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -565,7 +565,7 @@ Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc index 0179168be93bcd..70c09bc84ac275 100644 --- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -60,7 +60,7 @@ Status GetTestDevice(Session* session, string* test_device) { *test_device = found_gpu ? "GPU" : "CPU"; VLOG(2) << "Using test device " << *test_device; - return OkStatus(); + return absl::OkStatus(); } void FillZeros(Tensor* tensor) { diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 4914efdba55a07..23eb33224dc24b 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -114,7 +114,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, return errors::InvalidArgument("Invalid function argument"); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace Status GraphCompiler::Compile() { @@ -204,7 +204,7 @@ Status GraphCompiler::Compile() { } } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -221,7 +221,7 @@ Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, " does not have 'func' field set"); } *func = attr_value->func(); - return OkStatus(); + return absl::OkStatus(); } if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) { @@ -230,7 +230,7 @@ Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, func->set_name(FunctionLibraryDefinition::kGradientOp); } *func->mutable_attr() = 
node.def().attr(); - return OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index 68c576a52cba73..ac064805f1a470 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -106,7 +106,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, graph->RemoveEdge(edge); } } - return OkStatus(); + return absl::OkStatus(); } // Each fetch id identifies the positional output of some node. For each fetch @@ -138,7 +138,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, .Finalize(graph, &retval_node)); retval_nodes->insert(retval_node); } - return OkStatus(); + return absl::OkStatus(); } // RewriteAndPruneGraph identifies input and output edges (named by the feed and @@ -192,7 +192,7 @@ Status RewriteAndPruneGraph( ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } - return OkStatus(); + return absl::OkStatus(); } // CollectArgNodes collects _Arg nodes from the graph, and performs basic @@ -224,7 +224,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { } arg_nodes->push_back(index_node.second); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -243,7 +243,7 @@ Status CreateXlaArgs(const Graph& graph, TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } - return OkStatus(); + return absl::OkStatus(); } void PopulateXlaArgs(const tf2xla::Config& config, @@ -306,7 +306,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); *graph = std::move(g); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto 
b/tensorflow/compiler/tf2xla/host_compute_metadata.proto index 43ab371a217e6c..9e6eec2cddc99e 100644 --- a/tensorflow/compiler/tf2xla/host_compute_metadata.proto +++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto @@ -1,19 +1,21 @@ syntax = "proto3"; package tensorflow.tf2xla; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + option cc_enable_arenas = true; option java_outer_classname = "Tf2XlaProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.tf2xla"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - // TensorMetadata indicates the type and shape of a Tensor that is // part of a host compute transfer. message TensorMetadata { DataType type = 1; TensorShapeProto shape = 2; + int64 channel_id = 3; } // HostTransferMetadata describes a transfer either from host to device diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc index 80a43e2026d875..ba99f9b1297542 100644 --- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc @@ -96,7 +96,6 @@ class CollectiveReduceV2Op : public XlaOpKernel { void operator=(const CollectiveReduceV2Op&) = delete; }; - REGISTER_XLA_OP(Name("CollectiveReduceV2") .CompileTimeConstantInput("group_key") .CompileTimeConstantInput("group_size"), @@ -106,4 +105,13 @@ REGISTER_XLA_OP(Name("CollectiveAssignGroupV2") .CompileTimeConstantInput("group_assignment"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("XlaReduceScatter") + .CompileTimeConstantInput("group_assignment") + .CompileTimeConstantInput("scatter_dimension"), + MlirXlaOpKernel); + +REGISTER_XLA_OP( + Name("XlaAllReduce").CompileTimeConstantInput("group_assignment"), + MlirXlaOpKernel); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc 
b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c2bf7f7f606432..f44d800481fe6a 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -148,7 +148,7 @@ Status CheckConvAttrs(const ConvOpAttrs& attrs) { attrs.dilations[input_dim]); } } - return OkStatus(); + return absl::OkStatus(); } // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc index 2146284bd69864..92138c9663c556 100644 --- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc @@ -277,7 +277,7 @@ class FusedConv2DInt8Op : public XlaOpKernel { } ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 59d196c351c99b..5877aea0269643 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -86,7 +86,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, *gather_output = xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); - return OkStatus(); + return absl::OkStatus(); } for (int64_t i = 0; i < num_index_dims; ++i) { @@ -152,7 +152,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, } *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); - return OkStatus(); + return absl::OkStatus(); } Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, @@ -236,7 +236,7 @@ Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, /*indices_are_nd=*/false, context->expected_output_dtype(0), index_type, context->builder(), gather_output)); } - return OkStatus(); + return absl::OkStatus(); 
} class GatherOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 945ae96b46c327..b2d0b2e1d418f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -181,7 +181,7 @@ static Status ValidateShapes(XlaOpKernelContext* ctx, "Mismatch in resource of then and else branch for resource ", i); } } - return OkStatus(); + return absl::OkStatus(); } // TODO(b/35949885): There is duplication here with the handling of the diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index fbdcecdcf95dfd..60ef289567fe8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -53,8 +53,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -374,7 +374,7 @@ Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( output_shape.IsTuple() ? xla::GetTupleElement(out, i) : out); } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -552,7 +552,7 @@ Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx, num_dimensions * sizeof(int32_t)); } } - return OkStatus(); + return absl::OkStatus(); } class FakeDeviceContext : public DeviceContext { @@ -569,8 +569,7 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, // Look up the platform only once, for a small performance gain. 
static Status* platform_status = nullptr; static se::Platform* platform = [&]() -> se::Platform* { - StatusOr p = - se::MultiPlatformManager::PlatformWithName("CUDA"); + StatusOr p = se::PlatformManager::PlatformWithName("CUDA"); if (!p.ok()) { platform_status = new Status(p.status()); return nullptr; @@ -708,7 +707,7 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, } TF_RETURN_IF_ERROR(ctx.status()); - return OkStatus(); + return absl::OkStatus(); } void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque, diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 9a86b26fb8a623..258e0b0d47f6a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -98,7 +98,7 @@ class ListDiffOp : public XlaOpKernel { xla::ConstantR1(context->builder(), val_output)); context->SetOutput(1, xla::ConstantR1(context->builder(), idx_output)); - return OkStatus(); + return absl::OkStatus(); } template diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index fdd38e2f6beb32..a82cd2e3b85db1 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -55,7 +55,7 @@ static Status ValidateKernelSizes(const T& ksizes) { " must be positive but is ", ksizes[i]); } } - return OkStatus(); + return absl::OkStatus(); } template @@ -67,7 +67,7 @@ static Status ValidateStrides(const T& strides) { " must be positive but is ", strides[i]); } } - return OkStatus(); + return absl::OkStatus(); } // Superclass of pooling ops. 
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index ce20763c6146eb..81282578bb2ee6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -101,7 +101,7 @@ class MaxOp : public XlaReductionOp { "Unsupported PrimitiveType in MaxOp: '", xla::PrimitiveType_Name(xla_reduction_type), "'"); } else { - return OkStatus(); + return absl::OkStatus(); } } diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 77062741c0f91f..55eaf3db8a7570 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -58,7 +58,7 @@ Status ValidateUpdateShape(const TensorShape& buffer_shape, }; if (updates_shape.dims() == 0 && broadcast_scalar_update) { - return OkStatus(); + return absl::OkStatus(); } if (updates_shape.dims() < batch_dim) return shape_err(); @@ -81,7 +81,7 @@ Status ValidateUpdateShape(const TensorShape& buffer_shape, return shape_err(); } } - return OkStatus(); + return absl::OkStatus(); } class ScatterNdOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 06f7160e392456..1cd296b349a9dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -42,7 +42,7 @@ Status TensorShapeToConstant(const TensorShape& input_shape, vec(i) = dim_size; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index c175a5584e8f2d..3c1ca648769d95 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -91,7 +91,7 @@ Status 
GetAndValidateAttributes(OpKernelConstruction* ctx, paddings.assign(expected_rank, 0); } - return OkStatus(); + return absl::OkStatus(); } std::vector GetSliceIndices(absl::Span num_partitions, @@ -174,10 +174,10 @@ class XlaSplitNDBaseOp : public XlaOpKernel { xla::Pad(input, xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0), padding_config)); - return OkStatus(); + return absl::OkStatus(); } else if (num_slices_ == 1) { ctx->SetOutput(/*index=*/0, input); - return OkStatus(); + return absl::OkStatus(); } // Slice shape with optional padding. @@ -242,7 +242,7 @@ class XlaSplitNDBaseOp : public XlaOpKernel { slice_limit_indices, slice_strides)); } } - return OkStatus(); + return absl::OkStatus(); } private: @@ -426,7 +426,7 @@ class XlaConcatNDBaseOp : public XlaOpKernel { output_shape.push_back(max_dim_size - paddings_[dim]); } - return OkStatus(); + return absl::OkStatus(); } std::vector num_concats_; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index e74381aa6f24b4..8131769503086c 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -85,7 +85,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, actual_shape.DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } class StackOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 6fbb413b46bc3b..d5746a5bfd729b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -142,7 +142,7 @@ Status CheckStateShape(xla::RandomAlgorithm alg, const TensorShape& shape) { return errors::InvalidArgument("The size of the state must be at least ", min_state_size, "; got ", state_size); } - return OkStatus(); + return absl::OkStatus(); } StatusOr ResolveAlg(int 
alg_id) { @@ -227,7 +227,7 @@ Status CompileImpl( var = BitcastConvertType(var, state_element_type); TF_RETURN_IF_ERROR( ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); - return OkStatus(); + return absl::OkStatus(); } class StatefulUniformOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 22614039459a90..aca9973e118c14 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -84,7 +84,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, shape.DebugString()); } } - return OkStatus(); + return absl::OkStatus(); } // Checks that the TensorArray 'resource' has been initialized, and has type @@ -106,14 +106,14 @@ Status CheckTensorArrayIsInitialized(const string& op_name, " but op has dtype ", DataTypeString(dtype), "."); } - return OkStatus(); + return absl::OkStatus(); } Status GetTensorArrayShape(const XlaResource* resource, xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->max_array_size()); - return OkStatus(); + return absl::OkStatus(); } // Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 7e4c16ca189de3..8f8e6bd90ca448 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -119,19 +119,19 @@ Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input, bool is_compile_time_constant = is_compile_time_constant_or.value(); if (!is_compile_time_constant) { *got_shape = false; - return OkStatus(); + return absl::OkStatus(); } PartialTensorShape partial_shape; TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape)); if (!partial_shape.IsFullyDefined()) { 
*got_shape = false; - return OkStatus(); + return absl::OkStatus(); } *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes()); *got_shape = true; - return OkStatus(); + return absl::OkStatus(); } class TensorListReserveOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 575a3f400d899b..d7a1b5f970561a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -117,7 +117,7 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *is_initialized = list_shape.IsTuple(); - return OkStatus(); + return absl::OkStatus(); } Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { @@ -128,14 +128,14 @@ Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { } TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2); - return OkStatus(); + return absl::OkStatus(); } Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, xla::XlaOp* output_list) { TF_RET_CHECK(buffer.builder()); *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { @@ -146,7 +146,7 @@ Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { } TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { @@ -156,7 +156,7 @@ Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { 
return errors::InvalidArgument("TensorList is not initialized"); } *buffer = xla::GetTupleElement(list, 0); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { @@ -168,7 +168,7 @@ Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); *push_index = xla::GetTupleElement(list, tuple_size - 1); - return OkStatus(); + return absl::OkStatus(); } Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, @@ -187,7 +187,7 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, } result_parts.push_back(push_index); *result = xla::Tuple(list.builder(), result_parts); - return OkStatus(); + return absl::OkStatus(); } xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, @@ -222,7 +222,7 @@ Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, *leading_dim = list_shape.dimensions(0); *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0); } - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListShapeFromElementTensorListShape( @@ -244,7 +244,7 @@ Status GetTensorListShapeFromElementTensorListShape( shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); - return OkStatus(); + return absl::OkStatus(); } Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, @@ -267,7 +267,7 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, shapes.push_back(xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); - return OkStatus(); + return absl::OkStatus(); } Status CreateZerosTensorListWithShape( @@ -296,7 +296,7 @@ Status CreateZerosTensorListWithShape( .element_type() == xla::S32); 
elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32))); *list = xla::Tuple(b, elements); - return OkStatus(); + return absl::OkStatus(); } Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, @@ -330,7 +330,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, ", expected: ", list_shape.DebugString()); } *initialized_list = list; - return OkStatus(); + return absl::OkStatus(); } else { // Prepare dynamic dimension dimensions for zero tensor list. The dynamic // sizes are created by reading the dynamic dimension size of sub-elements. @@ -414,7 +414,7 @@ Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, result_parts.push_back(updated_push_index); *result = xla::Tuple(b, result_parts); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, @@ -463,7 +463,7 @@ Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, *element_result = element_result_parts[0]; } - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, @@ -499,7 +499,7 @@ Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, result_parts.push_back(updated_list_part); result_parts.push_back(xla::GetTupleElement(list, 1)); *result = xla::Tuple(b, result_parts); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, @@ -541,7 +541,7 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, } slice_shape.erase(slice_shape.begin()); *result = xla::Reshape(read, slice_shape); - return OkStatus(); + return absl::OkStatus(); } Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, @@ -558,7 +558,7 @@ Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, std::vector result_parts{tensor, xla::ConstantR0(b, push_index)}; *result = xla::Tuple(b, result_parts); 
- return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc index 015a0ce40e80d3..e7ba8d13082849 100644 --- a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc @@ -42,7 +42,7 @@ class ToBoolOp : public XlaOpKernel { if (shape.rank() == 0) { auto result = xla::Ne(ctx->Input(0), xla::ZerosLike(input)); ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } // Otherwise, any input tensor with elements returns True. Input tensor @@ -54,7 +54,7 @@ class ToBoolOp : public XlaOpKernel { auto result = xla::Ne(num_elements, xla::ZerosLike(num_elements)); ctx->SetOutput(0, result); - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 7cdf30594581c6..27cfc014bacb04 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -41,7 +41,7 @@ Status ValidateAssignUpdateVariableOpShapes(XlaOpKernelContext* ctx) { ctx->GetVariableTypeAndShape(0, &variable_dtype, &variable_shape)); TF_RETURN_IF_ERROR( ValidateAssignUpdateVariableOpShapes(variable_shape, value_shape)); - return OkStatus(); + return absl::OkStatus(); } class VarIsInitializedOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index c3370dcde64b70..d6685cc1e1d965 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -77,7 +77,7 @@ Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx, } } } - return OkStatus(); + return absl::OkStatus(); } // Builds XlaCompiler argument descriptions `args` from `ctx`. 
@@ -128,7 +128,7 @@ Status MakeXlaCompilerArgumentsFromInputs( } } } - return OkStatus(); + return absl::OkStatus(); } // Populates loop invariant indices to true in `loop_invariants`. @@ -186,7 +186,7 @@ Status ConvertLoopInvariantsToConst( compile_time_const_arg_indices->at(arg_idx) = true; (*num_compile_time_const_args)++; } - return OkStatus(); + return absl::OkStatus(); } Status VerifyBodyInputAndOutputShapeMatch( @@ -213,7 +213,7 @@ Status VerifyBodyInputAndOutputShapeMatch( xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(body_output_shape)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr BuildWrappedCond( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 589b9daec8772e..6a31628501cece 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -124,7 +124,7 @@ tsl::StatusOr> XlaCallModuleLoader::Create( return loader; } -tsl::Status XlaCallModuleLoader::SetPlatformIndex( +absl::Status XlaCallModuleLoader::SetPlatformIndex( absl::string_view compilation_platform) { int platform_index = -1; if (!platforms_.empty()) { @@ -186,7 +186,7 @@ tsl::Status XlaCallModuleLoader::SetPlatformIndex( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::RefineDynamicShapes( +absl::Status XlaCallModuleLoader::RefineDynamicShapes( llvm::ArrayRef input_shapes) { // Skip shape refinement for new versions if USES_SHAPE_POLYMORPHISM_ATTR=1 if (version_ >= kVersionStartSupportUsesShapePolymorphismAttr) { @@ -264,7 +264,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); mlir::RankedTensorType type = mlir::RankedTensorType::get(xla_dimensions, element_type); - // TODO(burmako): This fails with an obscure compilation error. 
+ // TODO(burmako): This fails with an obscure compilation error on Windows. // TF_ASSIGN_OR_RETURN( // mlir::Type type, // ConvertShapeToType(xla_shape, builder)); @@ -349,7 +349,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::LoadModule( +absl::Status XlaCallModuleLoader::LoadModule( mlir::MLIRContext *context, int version, std::string module_str, std::vector disabled_checks, std::vector platforms, int num_invocation_args, @@ -446,7 +446,7 @@ tsl::Status XlaCallModuleLoader::LoadModule( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::ValidateDialect() { +absl::Status XlaCallModuleLoader::ValidateDialect() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); bool moduleHasUnsupportedDialects = false; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index e77ce0effcf92c..3e9627ebcc29a2 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -52,7 +52,7 @@ class XlaCallModuleLoader { // Sets the platform index argument, if the module is compiled for multiple // platforms, and then erases the argument. - tsl::Status SetPlatformIndex(absl::string_view compilation_platform); + absl::Status SetPlatformIndex(absl::string_view compilation_platform); // Refines the dynamic module arguments based on the static argument shapes. // This assumes that the module has a "main" function without dimension args, @@ -71,10 +71,10 @@ class XlaCallModuleLoader { // cause lifetime issues. // The input_shapes includes only the non-token and the non-platform-index // arguments. - tsl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); + absl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); // Validates that the module only contains ops from valid dialects. 
- tsl::Status ValidateDialect(); + absl::Status ValidateDialect(); // Validates that the module represents a statically-shaped StableHLO program, // otherwise all sorts of weirdness might happen in the HLO exporter which is @@ -97,16 +97,16 @@ class XlaCallModuleLoader { XlaCallModuleLoader() = default; // Initializes the loader with the given serialized module string. - tsl::Status LoadModule(mlir::MLIRContext* context, int version, - std::string module_str, - std::vector disabled_checks, - std::vector platforms, - int num_invocation_args, - bool main_has_token_input_output); + absl::Status LoadModule(mlir::MLIRContext* context, int version, + std::string module_str, + std::vector disabled_checks, + std::vector platforms, + int num_invocation_args, + bool main_has_token_input_output); // Adds a wrapper for the "main" function to compute the platform index and // the dimension arguments. - tsl::Status AddMainWrapper(); + absl::Status AddMainWrapper(); mlir::MLIRContext* context_; int version_; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc index cb36059f62051a..2f1883a289cd3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc @@ -93,7 +93,7 @@ class XlaCustomCallV2Op : public XlaOpKernel { } } - return OkStatus(); + return absl::OkStatus(); } std::string call_target_name_; diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index eb4526a7ea3cc4..fb53e079b97857 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -74,7 +74,7 @@ Status RewriteLayoutWithShardedShape( layout_preference)); *xla_shape->mutable_layout() = per_device_xla_shape.layout(); } - return OkStatus(); + return absl::OkStatus(); } // There is a shape_representation_fn or sharding for an output, this function diff --git 
a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index ef70161fdcdcb2..da1b4182004e8a 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -51,7 +51,7 @@ Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { TF_ASSIGN_OR_RETURN(*lhs, xla::BroadcastTo(*lhs, bcast.output_shape())); TF_ASSIGN_OR_RETURN(*rhs, xla::BroadcastTo(*rhs, bcast.output_shape())); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 1ea2e98903d69e..c57fb68edf8e24 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" +#include #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" @@ -33,6 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -58,14 +60,6 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); -auto* replicated_graphs_without_device_type_counter = - tensorflow::monitoring::Counter<1>::New( - /* metric name */ - "/tensorflow/core/tf2xla/replicated_graphs_without_device_type_count", - /* metric description */ - "Tracks if any replicated graphs are without device type", - /* metric field */ "version"); - namespace { using ::mlir::ModuleOp; @@ -79,22 +73,6 @@ bool HasTPUDevice(const DeviceSet& device_set) { return false; } -// Check that graph has tf.StatefulPartitionedCall op with _XlaMustCompile. 
-bool RunNonReplicatedBridge(const Graph& graph) { - const std::string kStatefulPartitionedCallOp = "StatefulPartitionedCall"; - const std::string kXlaMustCompile = "_XlaMustCompile"; - for (const Node* node : graph.nodes()) { - auto node_op = node->type_string(); - if (node_op == kStatefulPartitionedCallOp) { - auto attr = node->attrs().FindByString(kXlaMustCompile); - if (attr != nullptr && attr->b() == true) { - return true; - } - } - } - return false; -} - bool HasTPUDevice(mlir::ModuleOp module) { mlir::TF::RuntimeDevices devices; if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false; @@ -105,49 +83,10 @@ bool HasTPUDevice(mlir::ModuleOp module) { }); } -bool IsReplicatedGraph(mlir::ModuleOp module) { - auto walk_result = module.walk([&](mlir::Operation* op) { - // TODO(b/223677572): Once the scope for new compilation and replication - // markers is expanded beyond bridge we can remove this check for - // `kTPUReplicateAttr`, we will then always have a `kCompileDeviceTypeAttr` - // in such cases (see above). - // TODO(b/229028654): Remove string conversion once we have C++17. - const llvm::StringRef tpu_replicate_attr_name(kTpuReplicateAttr.data(), - kTpuReplicateAttr.size()); - auto replicate_attr = - op->getAttrOfType(tpu_replicate_attr_name); - if (replicate_attr) return mlir::WalkResult::interrupt(); - return mlir::WalkResult::advance(); - }); - return walk_result.wasInterrupted(); -} - -bool IsReplicatedGraphWithoutDeviceType(mlir::ModuleOp module) { - return !HasTPUDevice(module) && IsReplicatedGraph(module); -} - -bool IsSingleCoreTPUGraph(mlir::ModuleOp module) { - auto walk_result = module.walk([&](mlir::Operation* op) { - // Check for ops with compile device type "TPU". This allows us to support - // TPU compilation without replication. Note that currently the compile - // device type is not set by default before bridge, only if eager context - // attribute `jit_compile_rewrite` is true. 
- // TODO(b/229028654): Remove string conversion once we have C++17. - const llvm::StringRef compile_device_type_attr_name( - kCompileDeviceTypeAttr.data(), kCompileDeviceTypeAttr.size()); - auto compilation_attr = - op->getAttrOfType(compile_device_type_attr_name); - if (compilation_attr && compilation_attr.getValue().str() == kTpuDevice) { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }); - return walk_result.wasInterrupted(); -} - -bool RunReplicatedBridge(mlir::ModuleOp module) { - if (HasTPUDevice(module) && IsReplicatedGraph(module)) return true; - return IsSingleCoreTPUGraph(module); +bool HasDevice(mlir::ModuleOp module) { + mlir::TF::RuntimeDevices devices; + if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false; + return !devices.device_names().empty(); } bool HasTPUPartitionedCallOpInModule(mlir::ModuleOp module) { @@ -201,10 +140,12 @@ absl::Status RunLowerToRuntimeOpsOnSubmodule(ModuleOp parent_module, // The config_proto param is a required input for all TF1 graphs but it is // redundant for TF2 graphs. MlirOptimizationPassState GetPassStateImpl( - bool run_replicated_bridge, const ConfigProto& config_proto, + bool is_supported_by_replicated_brige, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) { // Skip MLIR TF/XLA Bridge if no XLA-compilable ops are found. - if (!run_replicated_bridge && !RunNonReplicatedBridge(graph)) { + // TODO(b/324474356): also check the called function in the library. + if (!is_supported_by_replicated_brige && + !IsSupportedByNonReplicatedBridge(graph, /*function_library*/ nullptr)) { VLOG(3) << "Skipping MLIR Bridge, graph is not qualified to run the bridge"; return MlirOptimizationPassState::Disabled; } @@ -214,58 +155,43 @@ MlirOptimizationPassState GetPassStateImpl( // GetMlirBridgeRolloutPolicy will analyze a TPU graph if users have not // explicltly requested a policy. 
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( - graph, &function_library, config_proto, - /*run_replicated_bridge*/ run_replicated_bridge, + graph, &function_library, config_proto, is_supported_by_replicated_brige, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/false, /*record_stats=*/false); // GetPassState is called once before MlirBridgePass starts, and the pass // gets skipped if it is disabled. Log such cases in this function. The cases // where the pass is enabled will only be logged during their execution to // prevent them from being counted twice. - if (run_replicated_bridge) { - switch (policy) { - case MlirBridgeRolloutPolicy::kEnabledByUser: - return MlirOptimizationPassState::Enabled; - case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: - return MlirOptimizationPassState::FallbackEnabled; - case MlirBridgeRolloutPolicy::kDisabledByUser: - VLOG(1) << "Skipping MLIR TPU Bridge, disabled by user. " - "Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "disabled_by_user"); - return MlirOptimizationPassState::Disabled; - case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: - VLOG(1) << "Skipping MLIR TPU Bridge, disabled because " - "graph has unsupported features. Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "invalid_graph"); - // We set `uses_uninitialized_resource_args` to false here because the - // first phase of the bridge is not affected by uninitialized resource - // args. - // For Invalid Graph Analysis we need to log here because Run will not - // be called. 
- LogGraphFeatures(graph, &function_library, config_proto, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/false); - return MlirOptimizationPassState::Disabled; - } - } - // TODO(b/277112519): Have uniform behavior for GPU/CPU and TPU switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: return MlirOptimizationPassState::FallbackEnabled; - case MlirBridgeRolloutPolicy::kDisabledByUser: - VLOG(1) << "Skipping MLIR CPU/GPU Bridge, disabled by user."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, - "disabled_by_user"); + case MlirBridgeRolloutPolicy::kDisabledByUser: { + VLOG(1) << "Skipping MLIR " + << (is_supported_by_replicated_brige ? "Replicated" + : "Non-Replicated") + << " Bridge, disabled by user. " + "The fallback will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + is_supported_by_replicated_brige ? "tpu" : "cpu/gpu", "v2", true, + "disabled_by_user"); return MlirOptimizationPassState::Disabled; - default: - // This case should never be hit. Added here to be consistent with OSS - // implementation. - metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "v2", false, + } + case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: + // Graph analysis only runs on TPU graph. + VLOG(1) << "Skipping MLIR TPU Bridge, disabled because the " + "graph has unsupported features. The fallback will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, "invalid_graph"); + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. + // For Invalid Graph Analysis we need to log here because Run will not + // be called. 
+ LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/false); return MlirOptimizationPassState::Disabled; } } @@ -274,14 +200,18 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) const { + // While we do not use device type information to choose which pass pipeline + // to execute, it's needed for successful execution. if (!device_set) { // This is not expected in practice. VLOG(1) << "Device set is empty!"; return MlirOptimizationPassState::Disabled; } - return GetPassStateImpl(/*run_replicated_bridge*/ HasTPUDevice(*device_set), - config_proto, graph, function_library); + return GetPassStateImpl( + /*is_supported_by_replicated_brige*/ IsSupportedByReplicatedBridge( + graph, &function_library), + config_proto, graph, function_library); } // This runs the first phase of the "bridge", transforming the graph in a form @@ -297,18 +227,9 @@ Status MlirBridgePass::Run(const std::string& function_name, static absl::once_flag flag; absl::call_once(flag, UpdateLogVerbosityIfDefined, "TF_DEBUG_LOG_VERBOSITY"); - // Check if it's possible for a replicated graph to not have a device type. - if (IsReplicatedGraphWithoutDeviceType(module)) { - replicated_graphs_without_device_type_counter->GetCell("v2")->IncrementBy( - 1); - } - - // Check if the graph has any XLA-compilable ops. - // This check needs to precede GetPassState for instrumentation purposes. 
- bool run_replicated_bridge = RunReplicatedBridge(module); - if (!run_replicated_bridge && !RunNonReplicatedBridge(graph)) { - VLOG(1) << "Skipping MLIR TF2XLA Bridge, no XLA-compilable ops found."; - return OkStatus(); + if (!HasDevice(module)) { + LOG(INFO) << "No devices in " << function_name << "\n"; + return absl::OkStatus(); } if (HasTPUPartitionedCallOpInModule(module)) { @@ -320,8 +241,9 @@ Status MlirBridgePass::Run(const std::string& function_name, // TODO(b/241853328): Add caching of pass state and call logging/metrics // related to graph analysis from here. - auto pass_state = GetPassStateImpl(run_replicated_bridge, config_proto, graph, - function_library); + bool is_supported_by_replicated_brige = IsSupportedByReplicatedBridge(module); + auto pass_state = GetPassStateImpl(is_supported_by_replicated_brige, + config_proto, graph, function_library); if (pass_state == MlirOptimizationPassState::Disabled) { // GetPassState is called before run() and run() will only be called if the @@ -333,7 +255,7 @@ Status MlirBridgePass::Run(const std::string& function_name, } bool fallback_enabled = false; - if (run_replicated_bridge) { + if (is_supported_by_replicated_brige) { if (pass_state == MlirOptimizationPassState::FallbackEnabled) { // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource @@ -350,7 +272,7 @@ Status MlirBridgePass::Run(const std::string& function_name, TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_TPU_JIT, fallback_enabled, + module, /*is_supported_by_replicated_brige*/ true, fallback_enabled, function_name)); TF_RETURN_IF_ERROR( @@ -360,8 +282,8 @@ Status MlirBridgePass::Run(const std::string& function_name, VLOG(1) << "Running GPU/CPU Bridge"; TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( - module, tf2xla::v2::DeviceType::XLA_GPU_JIT, fallback_enabled, - 
function_name)); + module, /*is_supported_by_replicated_brige*/ false, + fallback_enabled, function_name)); TF_RETURN_IF_ERROR( tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline( @@ -376,14 +298,14 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) const { - // Skip MLIR TPU Bridge if no TPU devices found. - if (device_set && !HasTPUDevice(*device_set)) + // Skip MLIR Bridge if no potential XLA clusters are found. + if (!IsSupportedByReplicatedBridge(graph, &function_library)) return MlirOptimizationPassState::Disabled; // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, /*function_library=*/&function_library, config_proto, - /*run_replicated_bridge*/ true, + /*is_supported_by_replicated_brige*/ true, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true, /*record_stats=*/false); switch (policy) { @@ -423,14 +345,8 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, // Skip function graphs as MlirBridgePass will be used instead. if (options.is_function_graph) return OkStatus(); - // Check if it's possible for a replicated graph to not have a device type. - if (IsReplicatedGraphWithoutDeviceType(module)) { - replicated_graphs_without_device_type_counter->GetCell("v1")->IncrementBy( - 1); - } - // Skip MLIR TPU Bridge if no TPU devices or TPU ops found. 
- if (!RunReplicatedBridge(module)) { + if (!IsSupportedByReplicatedBridge(module)) { VLOG(1) << "Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops " "found"; return OkStatus(); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index edb2a40f4d332b..27a534296921cd 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -42,7 +42,7 @@ Status UnchangedRank(shape_inference::InferenceContext* c) { } else { c->set_output(0, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaBroadcastHelper") @@ -294,7 +294,7 @@ static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaDot") @@ -398,7 +398,7 @@ REGISTER_OP("XlaDynamicSlice") return UnchangedRank(c); } c->set_output(0, size_indices_value); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA DynamicSlice operator, documented at @@ -556,7 +556,7 @@ REGISTER_OP("XlaPad") } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Pad operator, documented at @@ -587,7 +587,7 @@ REGISTER_OP("XlaRecv") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Receives the named tensor from another XLA computation. Wraps the XLA Recv @@ -630,7 +630,7 @@ REGISTER_OP("XlaReduce") } else { c->set_output(0, c->input(0)); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Reduce operator, documented at @@ -684,7 +684,7 @@ REGISTER_OP("XlaVariadicReduce") c->set_output(i, c->input(i)); } } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the variadic XLA Reduce operator. 
@@ -768,7 +768,7 @@ REGISTER_OP("XlaVariadicReduceV2") for (int i = 0; i < nr_inputs; ++i) { c->set_output(i, output_shape); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the variadic XLA Reduce operator. @@ -828,7 +828,7 @@ REGISTER_OP("XlaRngBitGenerator") shape_inference::ShapeHandle output; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output)); c->set_output(1, output); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Stateless PRNG bit generator. @@ -912,7 +912,7 @@ REGISTER_OP("XlaKeyValueSort") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(1)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Sort operator, documented at @@ -938,7 +938,7 @@ REGISTER_OP("XlaVariadicSort") std::vector input_shapes; TF_RETURN_IF_ERROR(c->input("inputs", &input_shapes)); TF_RETURN_IF_ERROR(c->set_output("outputs", input_shapes)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA Sort operator, documented at @@ -1066,7 +1066,7 @@ REGISTER_OP("XlaSpmdFullToShardShape") dims.push_back(c->MakeDim(dim)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( An op used by XLA SPMD partitioner to switch from automatic partitioning to @@ -1092,7 +1092,7 @@ REGISTER_OP("XlaSpmdShardToFullShape") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( An op used by XLA SPMD partitioner to switch from manual partitioning to @@ -1119,7 +1119,7 @@ REGISTER_OP("XlaReplicaId") .Output("id: int32") .SetShapeFn([](shape_inference::InferenceContext* context) { context->set_output(0, context->MakeShape({})); - return OkStatus(); + return absl::OkStatus(); }) .Doc("Replica ID."); @@ -1212,7 +1212,7 @@ Status 
OptimizationBarrierShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_inputs(); ++i) { c->set_output(i, c->input(i)); } - return OkStatus(); + return absl::OkStatus(); } REGISTER_OP("XlaOptimizationBarrier") @@ -1258,7 +1258,7 @@ REGISTER_OP("XlaCustomCall") shape_inference::ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Wraps the XLA CustomCall operator @@ -1293,7 +1293,7 @@ REGISTER_OP("XlaCustomCallV2") TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shapes[i], &shape)); c->set_output(i, shape); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Emits an HLO `CustomCall` operation with multiple outputs. @@ -1346,7 +1346,7 @@ REGISTER_OP("XlaCallModule") << "] : " << c->DebugString(s); c->set_output(i, s); } - return OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Invokes a StableHLO module. @@ -1365,7 +1365,7 @@ version: Tracks changes the semantics of the op, to support backwards the op carries a StableHLO module with compatibility guarantees. From version 5, XLACallModule can include `stablehlo.custom_call` op to execute tf functions. From version 6 the op supports the `disabled_checks` attribute. - See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code. + See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code. module: A serialized computation, a text or bytecode representation of an mlir.Module. The return type must be a tuple if and only if the `Sout` is a list with 0 or more than 1 elements. 
The length of `Tout` and diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27940b7fb92c17..5846013128611c 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -667,7 +667,7 @@ def call_module_maximum_supported_version(): """Maximum version of XlaCallModule op supported. See versioning details documentation for the XlaCallModule op at: - https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code + https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code """ return 9 diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index b29d73adffb0a4..a081fa18891ba2 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -66,7 +66,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, if (first_resource_index == -1) { // No resource input. No need to rewrite. 
*need_rewrite = false; - return OkStatus(); + return absl::OkStatus(); } *need_rewrite = false; @@ -77,7 +77,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } } if (!*need_rewrite) { - return OkStatus(); + return absl::OkStatus(); } *resource_input_count = 0; @@ -100,7 +100,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original input index and rearranged input index, @@ -122,7 +122,7 @@ Status ReorderInputEdges(Graph* g, Node* n, g->RemoveEdge(e); g->AddEdge(src, src_output, n, new_dst_input)->DebugString(); } - return OkStatus(); + return absl::OkStatus(); } // For While node, given mapping between original input index and rearranged @@ -154,7 +154,7 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original input index and rearranged input index, change @@ -203,7 +203,7 @@ Status CalculateRetvalRearrange( TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index)); resource_retval_to_arg->insert(std::make_pair(i, src_index)); } - return OkStatus(); + return absl::OkStatus(); } // Given original output types and return value index mapping, return the new @@ -252,7 +252,7 @@ Status RearrangeOutputEdges(Node* n, Graph* g, g->AddEdge(n, iter->second, dst, dst_input); } } - return OkStatus(); + return absl::OkStatus(); } // Given mapping between original output index and rearranged output index, @@ -287,7 +287,7 @@ Status MaybeRewriteWhileNode( types, &input_need_rearrange, &resource_input_count, &index_mapping)); if (!input_need_rearrange) { *node_rewritten = false; - return OkStatus(); + return absl::OkStatus(); } *node_rewritten = true; @@ -379,7 +379,7 @@ Status MaybeRewriteWhileNode( n->ClearAttr(attr_name); n->AddAttr(attr_name, attr_value); } - return OkStatus(); + return 
absl::OkStatus(); } Status MaybeRewriteIfNode( @@ -403,7 +403,7 @@ Status MaybeRewriteIfNode( DT_RESOURCE) != out_types.end(); if (!input_need_rearrange && !has_resource_output) { *node_rewritten = false; - return OkStatus(); + return absl::OkStatus(); } *node_rewritten = true; @@ -514,7 +514,7 @@ Status MaybeRewriteIfNode( n->ClearAttr("Tout"); n->AddAttr("Tout", new_out_types); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -557,7 +557,7 @@ Status RearrangeFunctionArguments( } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_util.cc b/tensorflow/compiler/tf2xla/resource_util.cc index e91ce07e6d6983..4180b8f1330bcd 100644 --- a/tensorflow/compiler/tf2xla/resource_util.cc +++ b/tensorflow/compiler/tf2xla/resource_util.cc @@ -104,7 +104,7 @@ Status PropagateFromArgOp( int index; TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index)); - if (!resource_arg_indices.contains(index)) return OkStatus(); + if (!resource_arg_indices.contains(index)) return absl::OkStatus(); TF_RET_CHECK(function_name.has_value()) << "ResourceUsageAnalysis does not support analyzing _Arg nodes " @@ -122,7 +122,7 @@ Status PropagateFromArgOp( (*user_to_source)[o] = src_node_info; } - return OkStatus(); + return absl::OkStatus(); } Status UpdateResourceUsageFromFunctionBodyAnalysis( @@ -176,7 +176,7 @@ Status UpdateResourceUsageFromFunctionBodyAnalysis( } } - return OkStatus(); + return absl::OkStatus(); } Status PropagateThroughCallOp( @@ -219,7 +219,7 @@ Status PropagateThroughCallOp( TF_RETURN_IF_ERROR(UpdateResourceUsageFromFunctionBodyAnalysis( n, function_name, *fbody, called_function_source_to_path, user_to_source, source_to_path)); - return OkStatus(); + return absl::OkStatus(); } // Analyzes pass through values for Identity and IdentityN ops. 
@@ -246,7 +246,7 @@ Status PropagateThroughIdentityOp( } } - return OkStatus(); + return absl::OkStatus(); } Status AnalyzeResourceUsage( @@ -313,7 +313,7 @@ Status AnalyzeResourceUsage( it.first->dst()->type_string()); } - return OkStatus(); + return absl::OkStatus(); } } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b9093c2105cd0a..e01d1b919f9699 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -45,7 +45,7 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape, } else { layouts->insert(layouts->end(), shape.rank(), -1); } - return OkStatus(); + return absl::OkStatus(); } // Populate the output layout unless the minor_to_major array contains all -1 @@ -83,7 +83,7 @@ Status AssignLayout( layout = layout_func(*shape); } *shape->mutable_layout() = layout; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -100,7 +100,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.rank(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } - return OkStatus(); + return absl::OkStatus(); } // Convert a TensorShape into the equivalent XLA Shape proto. @@ -110,7 +110,7 @@ Status TensorShapeToXLAShape(DataType dtype, xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); - return OkStatus(); + return absl::OkStatus(); } Status TensorShapeToBoundedXLAShape(DataType dtype, @@ -122,7 +122,7 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, if (tensor_shape.unknown_rank()) { // For unknown shape, create a rank 1 size 0 tensor. 
*shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); - return OkStatus(); + return absl::OkStatus(); } if (tensor_shape.dims() != bound.dims()) { @@ -157,7 +157,7 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, } } *shape = result; - return OkStatus(); + return absl::OkStatus(); } xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, @@ -190,7 +190,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); - return OkStatus(); + return absl::OkStatus(); } StatusOr TensorShapeToXLAShape(DataType dtype, @@ -272,7 +272,7 @@ Status GetShapeWithLayout( VLOG(4) << "Shape[] = " << xla::ShapeUtil::HumanStringWithLayout(*output_shape); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 1a62857c537cc1..9446e4b4adadb9 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -74,7 +74,7 @@ Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { return errors::Internal("Unknown node type to set 'device_ordinal': ", node->DebugString()); } - return OkStatus(); + return absl::OkStatus(); } std::set CalculateTokenInputsForOutputToken(const Graph& g) { @@ -143,7 +143,7 @@ Status ParseHostComputeCoreList(absl::Span list_from_attr, } (*host_compute_core)[parts[0]] = core; } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 0fbf983058ca32..3fb8523ce71e0c 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -37,7 +37,7 @@ Status InstantiateFunctionForTest(const string& name, for (NodeDef& n : inst.nodes) { 
*result->gdef.add_node() = std::move(n); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 61768b0ff5557e..ef87b320cdcd0c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -125,7 +125,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, ++input_index; } } - return OkStatus(); + return absl::OkStatus(); } Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { @@ -141,7 +141,7 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { node.mutable_attr()->erase("allowed_devices"); } } - return OkStatus(); + return absl::OkStatus(); }; for (auto& node : *graph_def->mutable_node()) { TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); @@ -151,7 +151,7 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { TF_RETURN_IF_ERROR(update_var_handle_op_node(node)); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -164,7 +164,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR( ConvertGraphToXla(std::move(graph), config, client, computation)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index a43ad91c20d828..ad5de83c814d4c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -55,7 +55,7 @@ Status ValidateTensorId(const tf2xla::TensorId& id) { if (id.output_index() < 0) { return errors::InvalidArgument("TensorId output_index must be positive"); } - return OkStatus(); + return absl::OkStatus(); } Status CheckNameDuplicates(const string& kind, const string& name, @@ -65,7 +65,7 @@ Status CheckNameDuplicates(const string& kind, const string& name, return 
errors::InvalidArgument("duplicate ", kind, " name: ", name); } } - return OkStatus(); + return absl::OkStatus(); } Status CheckFeedFetchNameConflicts(const string& kind, @@ -79,7 +79,7 @@ Status CheckFeedFetchNameConflicts(const string& kind, " and ", name_data); } } - return OkStatus(); + return absl::OkStatus(); } // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to @@ -108,7 +108,7 @@ Status CopyAssociatedFunctions(Graph* g, } } } - return OkStatus(); + return absl::OkStatus(); } // Replaces the single edge feeding into {dst,dst_input} with a new @@ -162,7 +162,7 @@ Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output, } } } - return OkStatus(); + return absl::OkStatus(); } // For graph `g`, replaces _Arg nodes whose "index" attribute is in @@ -190,7 +190,7 @@ Status ReplaceArgUsageWithConstNode( TF_RETURN_IF_ERROR( ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node)); } - return OkStatus(); + return absl::OkStatus(); } // Replaces the single input to _Retval nodes with an index in the keys of @@ -220,7 +220,7 @@ Status ReplaceRetvalInputWithArg( ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0) .status()); } - return OkStatus(); + return absl::OkStatus(); } // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites @@ -276,7 +276,7 @@ Status PropagateConstIntoFuncAttr( // Copy associated functions. 
TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld)); - return OkStatus(); + return absl::OkStatus(); } // For an "If" node in graph `g`, if it has Const node inputs, rewrite its @@ -295,7 +295,7 @@ Status PropagateConstIntoIfNode(Graph* g, Node* if_node, } } if (const_input_index_to_node.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Rewrite "then_branch" and "else_branch" function, replace usage of those @@ -306,7 +306,7 @@ Status PropagateConstIntoIfNode(Graph* g, Node* if_node, if_node, attr_name, const_input_index_to_node, lookup_fld, fld)); } - return OkStatus(); + return absl::OkStatus(); } using GraphCache = absl::flat_hash_map>; @@ -456,7 +456,7 @@ Status PropagateConstIntoAndAroundWhileNode( const_input_index_to_node[i] = input_edge->src(); } if (const_input_index_to_node.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with @@ -473,7 +473,7 @@ Status PropagateConstIntoAndAroundWhileNode( TF_RETURN_IF_ERROR( ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second)); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -502,7 +502,7 @@ Status ValidateConfig(const tf2xla::Config& config) { if (config.fetch().empty()) { return errors::InvalidArgument("fetches must be specified"); } - return OkStatus(); + return absl::OkStatus(); } Status AddPlaceholdersForFeeds( @@ -599,7 +599,7 @@ Status AddPlaceholdersForFeeds( } } - return OkStatus(); + return absl::OkStatus(); } Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, @@ -664,7 +664,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, *out->add_node() = node; } } - return OkStatus(); + return absl::OkStatus(); } string TensorIdToString(const tf2xla::TensorId& id) { @@ -695,7 +695,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { n->set_assigned_device_name(matching_node->assigned_device_name()); 
n->set_requested_device(matching_node->requested_device()); } - return OkStatus(); + return absl::OkStatus(); } void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, @@ -858,7 +858,7 @@ Status RewriteAssociatedFunction( } } - return OkStatus(); + return absl::OkStatus(); } Status CachedFunctionHandles::GetOrInstantiate( @@ -868,12 +868,12 @@ Status CachedFunctionHandles::GetOrInstantiate( auto iter = handles_.find(canonicalized_name); if (iter != handles_.end()) { *handle = iter->second; - return OkStatus(); + return absl::OkStatus(); } TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle)); handles_[canonicalized_name] = *handle; - return OkStatus(); + return absl::OkStatus(); } Status CachedFunctionHandles::ReleaseAllHandles() { @@ -965,7 +965,7 @@ Status PropagateConstIntoFunctionalNodes( } } } - return OkStatus(); + return absl::OkStatus(); } Status PruneUnreachableFunctionsFromGraph(const Graph& g, @@ -979,7 +979,7 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g, TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name)); } } - return OkStatus(); + return absl::OkStatus(); } Status RewriteTensorListWithConstElement(Graph* g, @@ -1116,7 +1116,7 @@ Status RewriteTensorListWithConstElement(Graph* g, bwd_while->ClearAttr("body"); bwd_while->AddAttr("body", bwd_body_attr); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index d04dbd314a4931..d1ea22324c7e8c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -41,7 +41,7 @@ namespace tensorflow { namespace { void ExpectErrorContains(const Status& status, absl::string_view str) { - EXPECT_NE(OkStatus(), status); + EXPECT_NE(absl::OkStatus(), status); EXPECT_TRUE(absl::StrContains(status.message(), str)) << "expected error: " << status.message() << " to contain: " << str; } diff --git 
a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 335cfdf37b1605..8221aa28b27b0b 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -28,66 +28,66 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { switch (data_type) { case tensorflow::DT_BOOL: *type = xla::PRED; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT4: *type = xla::S4; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT8: case tensorflow::DT_QINT8: *type = xla::S8; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT16: case tensorflow::DT_QINT16: *type = xla::S16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT32: case tensorflow::DT_QINT32: *type = xla::S32; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_INT64: *type = xla::S64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT4: *type = xla::U4; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT8: case tensorflow::DT_QUINT8: *type = xla::U8; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT16: case tensorflow::DT_QUINT16: *type = xla::U16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT32: *type = xla::U32; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_UINT64: *type = xla::U64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E5M2: *type = xla::F8E5M2; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT8_E4M3FN: *type = xla::F8E4M3FN; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_HALF: *type = xla::F16; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_FLOAT: *type = xla::F32; - return OkStatus(); + return absl::OkStatus(); case 
tensorflow::DT_DOUBLE: *type = xla::F64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_COMPLEX64: *type = xla::C64; - return OkStatus(); + return absl::OkStatus(); case tensorflow::DT_COMPLEX128: *type = xla::C128; - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType: '", diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 240e4bb1a78ceb..bbdf5c7d2c74fa 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -140,7 +140,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, VLOG(4) << "Done"; } -Status XlaCompilationDevice::Sync() { return OkStatus(); } +Status XlaCompilationDevice::Sync() { return absl::OkStatus(); } Status XlaCompilationDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index bb8b29de5b9acf..1ab4f2fd3a9a81 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -75,6 +75,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/util/debug_data_dumper.h" #include "tensorflow/core/util/dump_graph.h" #include "tsl/platform/tensor_float_32_utils.h" @@ -99,7 +100,7 @@ Status CheckSignature(const DataTypeVector& types, " but function parameter has type ", DataTypeString(types[i])); } } - return OkStatus(); + return absl::OkStatus(); } // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for @@ -444,7 +445,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape()); *output_shape = program_shape.result(); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -518,7 +519,7 @@ string XlaCompiler::Argument::ShapeHumanString() const { XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), - initialization_status_(OkStatus()), + initialization_status_(absl::OkStatus()), next_step_id_(1), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_(absl::WrapUnique(device_)) { @@ -571,7 +572,7 @@ static Status GetFunctionBody(const NameAttrList& function, *fbody = flib_runtime->GetFunctionBody(handle); TF_RET_CHECK(*fbody); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::FindFunctionBody(const NameAttrList& function, @@ -598,7 +599,7 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, } VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { @@ -815,7 +816,7 @@ Status XlaCompiler::CompileFunction( auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; - return OkStatus(); + return absl::OkStatus(); } const FunctionBody* fbody; @@ -927,7 +928,7 @@ Status XlaCompiler::CompileFunction( VLOG(1) << 
"===================================================="; cache_[{function_id, arg_vector}] = *result; - return OkStatus(); + return absl::OkStatus(); } // Computes the XLA shape for argument 'arg'. @@ -975,12 +976,12 @@ Status XlaCompiler::XLAShapeForArgument( arg.type, std::get(arg.shape), xla_shape)); } } - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kTensorList: { TF_RET_CHECK(absl::holds_alternative(arg.shape)); *xla_shape = std::get(arg.shape); - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kConstantResource: case XlaCompiler::Argument::kResource: { @@ -1000,7 +1001,7 @@ Status XlaCompiler::XLAShapeForArgument( TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, arg.fast_mem, options_.shape_determination_fns, xla_shape)); - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kTensorArray: { if (arg.max_array_size < 0) { @@ -1018,7 +1019,7 @@ Status XlaCompiler::XLAShapeForArgument( arg.tensor_array_gradients.size() + 1, *xla_shape); *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); } - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kStack: { if (arg.max_array_size < 0) { @@ -1034,7 +1035,7 @@ Status XlaCompiler::XLAShapeForArgument( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); *xla_shape = xla::ShapeUtil::MakeTupleShape( {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); - return OkStatus(); + return absl::OkStatus(); } case XlaResource::kInvalid: @@ -1044,7 +1045,7 @@ Status XlaCompiler::XLAShapeForArgument( } case XlaCompiler::Argument::kToken: { *xla_shape = xla::ShapeUtil::MakeTokenShape(); - return OkStatus(); + return absl::OkStatus(); } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); @@ -1143,7 +1144,7 @@ Status XlaCompiler::BuildArguments( } if (input_to_args->empty() && !use_tuple_arg) { - return OkStatus(); + return absl::OkStatus(); } // `arg_to_inputs[c] 
= d` means that the c'th original arg index corresponds @@ -1304,7 +1305,7 @@ Status XlaCompiler::BuildArguments( } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1320,7 +1321,7 @@ Status ValidateFunctionDef(const FunctionDef* fdef, const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); } - return OkStatus(); + return absl::OkStatus(); } // If node is PartitionedCall or StatefulPartitionedCall, returns the @@ -1339,10 +1340,10 @@ Status GetPotentialFunctionName(const Node& node, const string** name) { " does not have 'func' field set"); } *name = &attr_value->func().name(); - return OkStatus(); + return absl::OkStatus(); } *name = &node.type_string(); - return OkStatus(); + return absl::OkStatus(); } // Check that the graph doesn't have any invalid nodes (e.g. incompatible with @@ -1378,7 +1379,7 @@ Status ValidateGraph(const Graph* graph, return errors::InvalidArgument(errmsg); } - return OkStatus(); + return absl::OkStatus(); }; for (const Node* node : graph->nodes()) { @@ -1401,7 +1402,7 @@ Status ValidateGraph(const Graph* graph, s = FindKernelDef(device_type, node->def(), nullptr, nullptr); TF_RETURN_IF_ERROR(maybe_error(node, s)); } - return OkStatus(); + return absl::OkStatus(); } void ConvertConstantsToExpressions(xla::XlaBuilder* builder, @@ -1478,7 +1479,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; - if (VLOG_IS_ON(2)) { + if (VLOG_IS_ON(2) || DEBUG_DATA_DUMPER()->ShouldDump(name, kDebugGroupMain)) { VLOG(2) << "XlaCompiler::CompileGraph: " << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, flib_runtime_->GetFunctionLibraryDefinition()); @@ -1602,10 +1603,28 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->input_mapping)); for (const auto& [key, send] : host_compute_sends_) { - 
*result->host_compute_metadata.add_device_to_host() = send; + auto* d2h = result->host_compute_metadata.add_device_to_host(); + *d2h = send; + + for (int i = 0; i < d2h->metadata_size(); ++i) { + const std::string channel_name = + GetDeviceToHostChannelName(d2h->key(), i); + xla::ChannelHandle handle; + TF_RETURN_IF_ERROR(GetDeviceToHostChannelHandle(channel_name, &handle)); + d2h->mutable_metadata(i)->set_channel_id(handle.handle()); + } } for (const auto& [key, recv] : host_compute_recvs_) { - *result->host_compute_metadata.add_host_to_device() = recv; + auto* h2d = result->host_compute_metadata.add_host_to_device(); + *h2d = recv; + + for (int i = 0; i < h2d->metadata_size(); ++i) { + const std::string channel_name = + GetHostToDeviceChannelName(h2d->key(), i); + xla::ChannelHandle handle; + TF_RETURN_IF_ERROR(GetHostToDeviceChannelHandle(channel_name, &handle)); + h2d->mutable_metadata(i)->set_channel_id(handle.handle()); + } } if (!tsl::tensor_float_32_execution_enabled()) { @@ -1617,7 +1636,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); result->collective_info = context->GetCollectiveInfo(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetChannelHandle(const string& key, @@ -1628,7 +1647,7 @@ Status XlaCompiler::GetChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, @@ -1640,7 +1659,7 @@ Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, @@ -1652,7 +1671,7 @@ 
Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, } *channel = result.first->second; VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1679,7 +1698,7 @@ Status XlaCompiler::SetDeviceToHostMetadata( tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -1687,7 +1706,7 @@ Status XlaCompiler::SetDeviceToHostMetadata( } tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; SetTransfer(key, types, shapes, &transfer); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetDeviceToHostShapes( @@ -1702,7 +1721,7 @@ Status XlaCompiler::GetDeviceToHostShapes( TensorShape shape(iter->second.metadata(i).shape()); shapes->push_back(shape); } - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::SetHostToDeviceMetadata( @@ -1713,7 +1732,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); @@ -1721,7 +1740,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( } tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; SetTransfer(key, types, shapes, &transfer); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::GetHostComputeControlDependency( @@ -1734,7 +1753,7 @@ Status XlaCompiler::GetHostComputeControlDependency( } else { *handle = iter->second; } - return OkStatus(); + return absl::OkStatus(); } Status 
XlaCompiler::SetHostComputeControlDependency( @@ -1746,7 +1765,7 @@ Status XlaCompiler::SetHostComputeControlDependency( host_compute_name); } host_compute_control_output_[host_compute_name] = handle; - return OkStatus(); + return absl::OkStatus(); } void XlaCompiler::PushNodeTokenMapping() { @@ -1760,7 +1779,7 @@ Status XlaCompiler::PopNodeTokenMapping() { "empty."); } node_token_mapping_stack_.pop(); - return OkStatus(); + return absl::OkStatus(); } Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { @@ -1774,7 +1793,7 @@ Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { return errors::FailedPrecondition("Token mapping already exists for node ", node_name); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaCompiler::GetNodeToken(const string& node_name) { diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 04f12c7ca575d4..ff444efe752640 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -195,7 +195,7 @@ Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult( result.collective_info->group_size) .status(); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaContext::RecordCollectiveInfo(int group_key, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index dc40bb47e8f8fe..5f99e7f284e26c 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -86,7 +86,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, *output = input.Clone(); output->mutable_shape_do_not_use()->Swap(&shape); - return OkStatus(); + return absl::OkStatus(); } Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, @@ -110,7 +110,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, xla::Eq(indices, xla::Iota(builder, 
iota_shape, axis), broadcast_dims), xla::Broadcast(on_value, output_shape.dim_sizes()), xla::Broadcast(off_value, output_shape.dim_sizes())); - return OkStatus(); + return absl::OkStatus(); } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { @@ -253,7 +253,7 @@ Status ResolveDeviceAssignment( }); run_options.set_device_assignment(&device_assignment); run_options.set_gpu_executable_run_options(&gpu_options); - return OkStatus(); + return absl::OkStatus(); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 0bebf471ecfbe9..ee483af794e26d 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -18,6 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#include + #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" @@ -33,6 +35,15 @@ namespace tensorflow { using XlaLayoutPreference = mlir::XlaLayoutPreference; +inline std::string GetDeviceToHostChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_dtoh_", index); +} +inline std::string GetHostToDeviceChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_htod_", index); +} + // Helper methods for building XLA computations. class XlaHelpers { public: diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index ad65c1708794fd..8e8d3f28d8a47a 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/multi_platform_manager.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/test.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -340,8 +340,8 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { string name_; }; - TF_EXPECT_OK(se::MultiPlatformManager::RegisterPlatform( - std::make_unique())); + TF_EXPECT_OK( + se::PlatformManager::RegisterPlatform(std::make_unique())); xla::Compiler::RegisterCompilerFactory(kFakePlatformId, []() { return std::unique_ptr(nullptr); }); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 237d115aa298e5..f70a7df612d149 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -196,7 +196,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( } TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); - return OkStatus(); + return absl::OkStatus(); } // Converts an int16, int32 or int64 scalar literal to an int64. @@ -214,7 +214,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be int16, int32, or int64"); } - return OkStatus(); + return absl::OkStatus(); } // Converts an float32 or float64 scalar literal to a float64. 
@@ -230,7 +230,7 @@ static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be either float32 or float64"); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsIntScalar( @@ -273,7 +273,7 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal, for (int64_t i = 0; i < size; ++i) { out->push_back(literal.Get({i})); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { @@ -288,7 +288,7 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { // TODO(b/176993339): Support resolving dynamism across computations so // resolving dynamism will not fail in those cases. *out = true; - return OkStatus(); + return absl::OkStatus(); } Tensor dynamism = dynamism_or_status.value(); @@ -302,7 +302,7 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); *out = literal.Get({}); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( @@ -332,7 +332,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {}) .value(); - return OkStatus(); + return absl::OkStatus(); } Tensor dynamism = dynamism_or_status.value(); @@ -346,7 +346,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( } TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp)); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( @@ -377,7 +377,7 @@ static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, } else { return errors::InvalidArgument("value must be either int32 or int64"); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsIntVector( @@ -424,11 +424,11 @@ 
Status XlaOpKernelContext::ConstantInputAsInt64Literal( for (int64_t i = 0; i < src_data.size(); ++i) { out->data()[i] = src_data[i]; } - return OkStatus(); + return absl::OkStatus(); } case xla::S64: *out = std::move(literal); - return OkStatus(); + return absl::OkStatus(); default: return errors::InvalidArgument( @@ -462,7 +462,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape, ", result: ", num_elements); } *shape = TensorShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputAsPartialShape( @@ -478,12 +478,12 @@ Status XlaOpKernelContext::ConstantInputAsPartialShape( "Cannot convert value to PartialTensorShape: ", shape_val); } *shape = PartialTensorShape(); // Shape with unknown rank. - return OkStatus(); + return absl::OkStatus(); } std::vector dims; TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); *shape = PartialTensorShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::InputList(absl::string_view name, @@ -498,7 +498,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } - return OkStatus(); + return absl::OkStatus(); } Status XlaOpKernelContext::ConstantInputList(absl::string_view name, @@ -510,7 +510,7 @@ Status XlaOpKernelContext::ConstantInputList(absl::string_view name, for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode)); } - return OkStatus(); + return absl::OkStatus(); } StatusOr XlaOpKernelContext::ConstantInputTensor( @@ -571,7 +571,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_ASSIGN_OR_RETURN(xla::Literal literal, HostTensorToLiteral(*expression->constant_value())); *value = xla::ConstantLiteral(ctx->builder(), literal); - return OkStatus(); + return absl::OkStatus(); } auto shape_determination_fns = 
ctx->compiler()->options().shape_determination_fns; @@ -590,7 +590,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -625,7 +625,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, } *type = variable->type(); *shape = variable->shape(); - return OkStatus(); + return absl::OkStatus(); } void XlaOpKernelContext::SetOutputExpression(int index, @@ -656,7 +656,7 @@ void XlaOpKernelContext::SetOutputExpression(int index, } XlaExpression::AssignExpressionToTensor(expression, context_->mutable_output(index)); - return OkStatus(); + return absl::OkStatus(); }(); if (!status.ok()) { SetStatus(status); @@ -697,7 +697,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { XlaExpression::CastExpressionFromTensor(context_->input(index)); TF_RET_CHECK(expression->resource() != nullptr); *resource = expression->resource(); - return OkStatus(); + return absl::OkStatus(); } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 986388ab32b9f2..0109f6a3f07ef3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -53,7 +53,7 @@ static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { &kernel_class_name)); VLOG(1) << "LaunchOpHasKernelForDevice" << " kernel_class_name: " << kernel_class_name; - return OkStatus(); + return absl::OkStatus(); } XlaOpRegistry::XlaOpRegistry() = default; @@ -437,7 +437,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { compile_time_constant_inputs = CompileTimeConstantInputArgNames(node_def.op()); if (compile_time_constant_inputs->empty()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -470,7 +470,7 @@ 
XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { } absl::c_sort(*result); - return OkStatus(); + return absl::OkStatus(); } /*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index d48c97d35c31f5..0e1d33a0c1c718 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -109,7 +109,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { } type_ = type; shape_ = shape; - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetValue(const xla::XlaOp& value) { @@ -120,7 +120,7 @@ Status XlaResource::SetValue(const xla::XlaOp& value) { } value_ = value; is_overwritten_ = true; - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { @@ -159,7 +159,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { default: LOG(FATAL) << "Invalid resource type"; } - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, @@ -183,7 +183,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, /*tensor_array_multiple_writes_aggregate=*/true)); } *gradient_out = gradient.get(); - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { @@ -198,7 +198,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { } *pack = xla::Tuple(builder, elems); } - return OkStatus(); + return absl::OkStatus(); } Status XlaResource::SetFromPack(const std::set& gradient_sources, @@ -229,7 +229,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, gradient->value_ = v; } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 
a3538f62e93456..dde60c6079d704 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -336,10 +336,8 @@ cc_library( ":lib_internal", ":protos_all_cc", "//tensorflow/core/kernels:required", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform/default/build_config:gtest", ] + tf_additional_test_deps(), ) @@ -499,7 +497,6 @@ cc_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/platform:mutex", - "@local_tsl//tsl/platform/default/build_config:minimal", "@local_tsl//tsl/framework:fixedpoint_types", "//tensorflow/core/platform:types", ], @@ -711,7 +708,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tensorflow_opensource", - "@local_tsl//tsl/platform/default/build_config:tensorflow_platform_specific", + "//tensorflow/core/platform/default/build_config:tensorflow_platform_specific", ], ) @@ -1142,8 +1139,8 @@ cc_library( deps = [ ":portable_tensorflow_lib", "//tensorflow/core/kernels:portable_tensorflow_kernels", + "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform/default/build_config:gtest", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_DatasetFingerprint.pbtxt b/tensorflow/core/api_def/base_api/api_def_DatasetFingerprint.pbtxt new file mode 100644 index 00000000000000..0d2e0ff1793861 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DatasetFingerprint.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "DatasetFingerprint" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <